diff --git a/LICENSE b/LICENSE index f1732fb47afc0..3c667bf45059a 100644 --- a/LICENSE +++ b/LICENSE @@ -754,7 +754,7 @@ SUCH DAMAGE. ======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java): +For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): ======================================================================== Copyright (C) 2008 The Android Open Source Project @@ -771,6 +771,25 @@ See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For LimitedInputStream + (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): +======================================================================== +Copyright (C) 2007 The Guava Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + ======================================================================== BSD-style licenses ======================================================================== diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index c5936b5038ac9..badd85ed48c82 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -39,6 +39,8 @@ $(function() { var column = "table ." + $(this).attr("name"); $(column).hide(); }); + // Stripe table rows after rows have been hidden to ensure correct striping. + stripeTables(); $("input:checkbox").click(function() { var column = "table ." + $(this).attr("name"); diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 32187ba6e8df0..6bb03015abb51 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -28,8 +28,3 @@ function stripeTables() { }); }); } - -/* Stripe all tables after pages finish loading. 
*/ -$(function() { - stripeTables(); -}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a2220e761ac98..db57712c83503 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -120,6 +120,20 @@ pre { border: none; } +.stacktrace-details { + max-height: 300px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stacktrace-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} + span.expand-additional-metrics { cursor: pointer; } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index c11f1db0064fd..ef93009a074e7 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -66,7 +66,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Lower and upper bounds on the number of executors. These are required. private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) - verifyBounds() // How long there must be backlogged tasks for before an addition is triggered private val schedulerBacklogTimeout = conf.getLong( @@ -77,9 +76,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) // How long an executor must be idle for before it is removed - private val removeThresholdSeconds = conf.getLong( + private val executorIdleTimeout = conf.getLong( "spark.dynamicAllocation.executorIdleTimeout", 600) + // During testing, the methods to actually kill and add executors are mocked out + private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + + validateSettings() + // Number of executors to add in the next round private var numExecutorsToAdd = 1 @@ -103,17 +107,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Polling loop interval (ms) private val intervalMillis: Long = 100 - // Whether we are testing this class. This should only be used internally. - private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) - // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock /** - * Verify that the lower and upper bounds on the number of executors are valid. + * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. 
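The renamed validateSettings() (body shown in the hunk that follows) widens the old bounds check to also reject non-positive timeouts and to require the external shuffle service. As a standalone sketch of the same checks, assuming only SparkConf and SparkException — the 60s/600s defaults are illustrative, since the real defaults are not all visible in this hunk:

```scala
import org.apache.spark.{SparkConf, SparkException}

// Standalone sketch (not the real ExecutorAllocationManager.validateSettings) of the
// checks this patch adds; default values here are illustrative.
object DynamicAllocationValidation {
  def validate(conf: SparkConf): Unit = {
    val min = conf.getInt("spark.dynamicAllocation.minExecutors", -1)
    val max = conf.getInt("spark.dynamicAllocation.maxExecutors", -1)
    val backlog = conf.getLong("spark.dynamicAllocation.schedulerBacklogTimeout", 60)
    val sustained =
      conf.getLong("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", backlog)
    val idle = conf.getLong("spark.dynamicAllocation.executorIdleTimeout", 600)
    val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)

    if (min < 0 || max < 0) {
      throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!")
    }
    if (min > max) {
      throw new SparkException(s"minExecutors ($min) must be <= maxExecutors ($max)!")
    }
    if (backlog <= 0 || sustained <= 0 || idle <= 0) {
      throw new SparkException("Dynamic allocation timeouts must be > 0!")
    }
    // Killing executors without the external shuffle service would lose their shuffle files.
    if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) {
      throw new SparkException(
        "Dynamic allocation of executors requires the external shuffle service " +
        "(spark.shuffle.service.enabled).")
    }
  }
}
```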
*/ - private def verifyBounds(): Unit = { + private def validateSettings(): Unit = { if (minNumExecutors < 0 || maxNumExecutors < 0) { throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") } @@ -124,6 +125,22 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } + if (schedulerBacklogTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") + } + if (sustainedSchedulerBacklogTimeout <= 0) { + throw new SparkException( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") + } + if (executorIdleTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + } + // Require external shuffle service for dynamic allocation + // Otherwise, we may lose shuffle files when killing executors + if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + throw new SparkException("Dynamic allocation of executors requires the external " + + "shuffle service. You may enable this through spark.shuffle.service.enabled.") + } } /** @@ -254,7 +271,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging val removeRequestAcknowledged = testing || sc.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -329,8 +346,8 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging private def onExecutorIdle(executorId: String): Unit = synchronized { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)") - removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 } } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0e0f1a7b2377e..dbff9d12b5ad7 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.sasl.SecretKeyHolder /** * Spark class responsible for security. @@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously * exchange messages. 
For this we use the Java SASL * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed @@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. * - * Since the connectionManager does asynchronous messages passing, the SASL + * Since the NioBlockTransferService does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to @@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil * and waits for the response from the server and does the handshake before sending * the real message. * + * The NettyBlockTransferService ensures that SASL authentication is performed + * synchronously prior to any other communication on a connection. This is done in + * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. + * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a @@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. */ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -337,4 +342,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * @return the secret key as a String if authentication is enabled, otherwise returns null */ def getSecretKey(): String = secretKey + + // Default SecurityManager only has a single secret key, so ignore appId. + override def getSaslUser(appId: String): String = getSaslUser() + override def getSecretKey(appId: String): String = getSecretKey() } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ad0a9017afead..4c6c86c7bad78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ getAll.filter { case (k, _) => isAkkaConf(k) } + /** + * Returns the Spark application id, valid in the Driver after TaskScheduler registration and + * from the start in the Executor. + */ + def getAppId: String = get("spark.app.id") + /** Does the configuration contain a given parameter? 
*/ def contains(key: String): Boolean = settings.contains(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8b4db783979ec..03ea672c813d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,9 +21,8 @@ import scala.language.implicitConversions import java.io._ import java.net.URI -import java.util.Arrays +import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger -import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} import scala.collection.generic.Growable @@ -41,6 +40,7 @@ import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import org.apache.spark.executor.TriggerThreadDump import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.JobProgressListener -import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.util._ /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { val applicationId: String = taskScheduler.applicationId() conf.set("spark.app.id", applicationId) + env.blockManager.initialize(applicationId) + val metricsSystem = env.metricsSystem // The metrics system for Driver need to be set spark.app.id to app ID. @@ -361,6 +363,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { override protected def childValue(parent: Properties): Properties = new Properties(parent) } + /** + * Called by the web UI to obtain executor thread dumps. This method may be expensive. + * Logs an error and returns None if we failed to obtain a thread dump, which could occur due + * to an executor being dead or unresponsive or due to network issues while sending the thread + * dump message back to the driver. 
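The driver-side implementation of getExecutorThreadDump follows in the hunk below. A hypothetical usage sketch of how a caller inside the org.apache.spark namespace (the method is private[spark]) might consume it; the ThreadStackTrace field names used here are assumptions:

```scala
// Hypothetical caller such as a web UI page; getExecutorThreadDump is private[spark],
// so this must live under org.apache.spark. ThreadStackTrace fields are assumed.
package org.apache.spark.ui

import org.apache.spark.SparkContext

object ThreadDumpExample {
  def printThreadDump(sc: SparkContext, executorId: String): Unit = {
    sc.getExecutorThreadDump(executorId) match {
      case Some(traces) =>
        traces.foreach(t => println(s"=== ${t.threadName} ===\n${t.stackTrace}"))
      case None =>
        println(s"No thread dump available for executor $executorId")
    }
  }
}
```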
+ */ + private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { + try { + if (executorId == SparkContext.DRIVER_IDENTIFIER) { + Some(Utils.getThreadDump()) + } else { + val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get + val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) + Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, + AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + } + } catch { + case e: Exception => + logError(s"Exception getting thread dump from executor $executorId", e) + None + } + } + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { @@ -535,6 +560,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** + * :: Experimental :: + * * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file * (useful for binary data) * @@ -577,6 +604,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { } /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e2f13accdfab5..e7454beddbfd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -276,7 +276,7 @@ object SparkEnv extends Logging { val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => - new NettyBlockTransferService(conf) + new NettyBlockTransferService(conf, securityManager) case "nio" => new NioBlockTransferService(conf, securityManager) } @@ -285,8 +285,9 @@ object SparkEnv extends Logging { "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala deleted file mode 100644 index a954fcc0c31fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.RealmCallback -import javax.security.sasl.RealmChoiceCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslClient -import javax.security.sasl.SaslException - -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 - -/** - * Implements SASL Client logic for Spark - */ -private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - * - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslClientCallbackHandler(securityMgr)) - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - */ - def firstToken(): Array[Byte] = { - synchronized { - val saslToken: Array[Byte] = - if (saslClient != null && saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken - } - } - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslClient != null) saslClient.isComplete() else false - } - } - - /** - * Respond to server's SASL token. - * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - synchronized { - if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - def dispose() { - synchronized { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends - CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - private val secretKey = securityMgr.getSecretKey() - private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8)) - - /** - * Implementation used to respond to SASL request from the server. - * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. 
- */ - override def handle(callbacks: Array[Callback]) { - logDebug("in the sasl client callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL client callback: setting username: " + userName) - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL client callback: setting userPassword") - pc.setPassword(userPassword) - } - case rc: RealmCallback => { - logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case cb: RealmChoiceCallback => {} - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala deleted file mode 100644 index 7c2afb364661f..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.AuthorizeCallback -import javax.security.sasl.RealmCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslException -import javax.security.sasl.SaslServer -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 -import org.apache.commons.net.util.Base64 - -/** - * Encapsulates SASL server logic - */ -private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { - - /** - * Actual SASL work done by this object from javax.security.sasl. - */ - private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, - SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslDigestCallbackHandler(securityMgr)) - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslServer != null) saslServer.isComplete() else false - } - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - def response(token: Array[Byte]): Array[Byte] = { - synchronized { - if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. 
- */ - def dispose() { - synchronized { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * for SASL DIGEST-MD5 mechanism - */ - private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) - extends CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - - override def handle(callbacks: Array[Callback]) { - logDebug("In the sasl server callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL server callback: setting username") - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL server callback: setting userPassword") - val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8)) - pc.setPassword(password) - } - case rc: RealmCallback => { - logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case ac: AuthorizeCallback => { - val authid = ac.getAuthenticationID() - val authzid = ac.getAuthorizationID() - if (authid.equals(authzid)) { - logDebug("set auth to true") - ac.setAuthorized(true) - } else { - logDebug("set auth to false") - ac.setAuthorized(false) - } - if (ac.isAuthorized()) { - logDebug("sasl server is authorized") - ac.setAuthorizedID(authzid) - } - } - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") - } - } - } -} - -private[spark] object SparkSaslServer { - - /** - * This is passed as the server name when creating the sasl client/server. - * This could be changed to be configurable in the future. - */ - val SASL_DEFAULT_REALM = "default" - - /** - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - val DIGEST = "DIGEST-MD5" - - /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. - */ - val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") - - /** - * Encode a byte[] identifier as a Base64-encoded string. - * - * @param identifier identifier to encode - * @return Base64-encoded string - */ - def encodeIdentifier(identifier: Array[Byte]): String = { - new String(Base64.encodeBase64(identifier), UTF_8) - } - - /** - * Encode a password as a base64-encoded char[] array. - * @param password as a byte array. - * @return password as a char array. 
- */ - def encodePassword(password: Array[Byte]): Array[Char] = { - new String(Base64.encodeBase64(password), UTF_8).toCharArray() - } -} - diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 202fba699ab26..af5fd8e0ac00c 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -69,11 +69,13 @@ case class FetchFailed( bmAddress: BlockManagerId, // Note that bmAddress can be null shuffleId: Int, mapId: Int, - reduceId: Int) + reduceId: Int, + message: String) extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString - s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)" + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + + s"message=\n$message\n)" } } @@ -81,15 +83,48 @@ case class FetchFailed( * :: DeveloperApi :: * Task failed due to a runtime exception. This is the most common failure case and also captures * user program exceptions. + * + * `stackTrace` contains the stack trace of the exception itself. It still exists for backward + * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to + * create `ExceptionFailure` as it will handle the backward compatibility properly. + * + * `fullStackTrace` is a better representation of the stack trace because it contains the whole + * stack trace including the exception and its causes */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], + fullStackTrace: String, metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) + + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + } + + override def toErrorString: String = + if (fullStackTrace == null) { + // fullStackTrace is added in 1.2.0 + // If fullStackTrace is null, use the old error string for backward compatibility + exceptionString(className, description, stackTrace) + } else { + fullStackTrace + } + + /** + * Return a nice string representation of the exception, including the stack trace. + * Note: It does not include the exception's causes, and is only used for backward compatibility. 
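Utils.exceptionString(e), which produces fullStackTrace in the new ExceptionFailure constructor above, is not part of this hunk; conceptually it only needs to capture what printStackTrace already emits, including the cause chain that the old className/description/stackTrace triple could not represent. A minimal sketch under that assumption:

```scala
import java.io.{PrintWriter, StringWriter}

// Sketch only: the real Utils.exceptionString(e) is not shown in this diff.
// printStackTrace walks the exception's cause chain, which is the extra
// information fullStackTrace carries over the old representation.
def fullStackTraceOf(e: Throwable): String = {
  val sw = new StringWriter()
  e.printStackTrace(new PrintWriter(sw, true))
  sw.toString
}
```

With that in place, `new ExceptionFailure(e, metrics)` fills both fields, and toErrorString only falls back to the 1.1-style string when fullStackTrace is null, i.e. for events written by older versions.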
+ */ + private def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e3aeba7e6c39d..5c6e8d32c5c8a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,11 +21,6 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import java.io.DataInputStream - -import org.apache.hadoop.io.{BytesWritable, LongWritable} -import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat} - import scala.collection.JavaConversions import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -33,6 +28,7 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration +import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -286,6 +282,8 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.binaryFiles(path, minPartitions)) /** + * :: Experimental :: + * * Read a directory of binary files from HDFS, a local file system (available on all nodes), * or any Hadoop-supported file system URI as a byte array. Each file is read as a single * record and returned in a key-value pair, where the key is the path of each file, @@ -312,15 +310,19 @@ class JavaSparkContext(val sc: SparkContext) * * @note Small files are preferred; very large files but may cause bad performance. */ + @Experimental def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files * @return An RDD of data with values, represented as byte arrays */ + @Experimental def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { new JavaRDD(sc.binaryRecords(path, recordLength)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 49dc95f349eac..5ba66178e2b78 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -61,8 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. 
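Dropping the batchSize guard in WritableToJavaConverter (shown in the hunk below) means every retained Writable is cloned, which is the safe default because Hadoop record readers reuse Writable instances across records. A reduced sketch of that conversion pattern (the full converter also handles arrays and maps recursively):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{IntWritable, Text, Writable, WritableUtils}

// Clone-before-keeping: RecordReaders reuse Writable objects, so any Writable that is
// not unwrapped to a plain value must be deep-copied via write/readFields.
def convertWritable(obj: Any, conf: Configuration): Any = obj match {
  case iw: IntWritable => iw.get()
  case t: Text         => t.toString
  case w: Writable     => WritableUtils.clone(w, conf)
  case other           => other
}
```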
*/ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]], - batchSize: Int) extends Converter[Any, Any] { + conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter( map.put(convertWritable(k), convertWritable(v)) } map - case w: Writable => - if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w + case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 61b125ef7c6c1..45beb8fc8c925 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,13 +21,13 @@ import java.io._ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import org.apache.spark.input.PortableDataStream + import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 -import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -397,22 +397,33 @@ private[spark] object PythonRDD extends Logging { newIter.asInstanceOf[Iterator[String]].foreach { str => writeUTF(str, dataOut) } - case pair: Tuple2[_, _] => - pair._1 match { - case bytePair: Array[Byte] => - newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair => - dataOut.writeInt(pair._1.length) - dataOut.write(pair._1) - dataOut.writeInt(pair._2.length) - dataOut.write(pair._2) - } - case stringPair: String => - newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair => - writeUTF(pair._1, dataOut) - writeUTF(pair._2, dataOut) - } - case other => - throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass) + case stream: PortableDataStream => + newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, stream: PortableDataStream) => + newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { + case (key, stream) => + writeUTF(key, dataOut) + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, value: String) => + newIter.asInstanceOf[Iterator[(String, String)]].foreach { + case (key, value) => + writeUTF(key, dataOut) + writeUTF(value, dataOut) + } + case (key: Array[Byte], value: Array[Byte]) => + newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { + case (key, value) => + dataOut.writeInt(key.length) + dataOut.write(key) + dataOut.writeInt(value.length) + dataOut.write(value) } case other => throw new SparkException("Unexpected element type " + first.getClass) @@ -442,7 +453,7 @@ private[spark] object PythonRDD extends Logging { val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new 
WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -468,7 +479,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -494,7 +505,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -537,7 +548,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -563,7 +574,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -746,104 +757,6 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } - - - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - */ - @deprecated("PySpark does not use it anymore", "1.1") - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - SerDeUtil.initialize() - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - - /** - * Convert an RDD of serialized Python tuple to Array (no recursive conversions). - * It is only used by pyspark.sql. 
- */ - def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { - - def toArray(obj: Any): Array[_] = { - obj match { - case objs: JArrayList[_] => - objs.toArray - case obj if obj.getClass.isArray => - obj.asInstanceOf[Array[_]].toArray - } - } - - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].map(toArray) - } else { - Seq(toArray(obj)) - } - } - }.toJavaRDD() - } - - private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { - private val pickle = new Pickler() - private var batch = 1 - private val buffer = new mutable.ArrayBuffer[Any] - - override def hasNext(): Boolean = iter.hasNext - - override def next(): Array[Byte] = { - while (iter.hasNext && buffer.length < batch) { - buffer += iter.next() - } - val bytes = pickle.dumps(buffer.toArray) - val size = bytes.length - // let 1M < size < 10M - if (size < 1024 * 1024) { - batch *= 2 - } else if (size > 1024 * 1024 * 10 && batch > 1) { - batch /= 2 - } - buffer.clear() - bytes - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. - */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - SerDeUtil.initialize() - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } private diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index ebdc3533e0992..a4153aaa926f8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -18,8 +18,13 @@ package org.apache.spark.api.python import java.nio.ByteOrder +import java.util.{ArrayList => JArrayList} + +import org.apache.spark.api.java.JavaRDD import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Failure import scala.util.Try @@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging { } initialize() + + /** + * Convert an RDD of Java objects to Array (no recursive conversions). + * It is only used by pyspark.sql. 
+ */ + def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = { + jrdd.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + }.toJavaRDD() + } + + /** + * Choose batch size based on size of objects + */ + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext: Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } + private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler val kt = Try { @@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging { */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) + rdd.mapPartitions { iter => - val pickle = new Pickler val cleaned = iter.map { case (k, v) => val key = if (keyFailed) k.toString else k val value = if (valueFailed) v.toString else v Array[Any](key, value) } - if (batchSize > 1) { - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + if (batchSize == 0) { + new AutoBatchedPickler(cleaned) } else { - cleaned.map(pickle.dumps(_)) + val pickle = new Pickler + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) } } } @@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging { /** * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. 
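The AutoBatchedPickler added above keeps each pickled chunk roughly between 1 MB and 10 MB by doubling or halving the batch size, and pairRDDToPython now switches to it whenever batchSize == 0. The sizing rule in isolation, as a small sketch:

```scala
// The adaptive rule used by AutoBatchedPickler: grow the batch while chunks stay
// under ~1 MB, shrink it (never below 1) once a chunk exceeds ~10 MB.
def nextBatchSize(batch: Int, lastChunkBytes: Int): Int = {
  if (lastChunkBytes < 1024 * 1024) batch * 2
  else if (lastChunkBytes > 10 * 1024 * 1024 && batch > 1) batch / 2
  else batch
}
```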
*/ - def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = { def isPair(obj: Any): Boolean = { - Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + Option(obj.getClass.getComponentType).exists(!_.isPrimitive) && obj.asInstanceOf[Array[_]].length == 2 } - pyRDD.mapPartitions { iter => - initialize() - val unpickle = new Unpickler - val unpickled = - if (batchSerialized) { - iter.flatMap { batch => - unpickle.loads(batch) match { - case objs: java.util.List[_] => collectionAsScalaIterable(objs) - case other => throw new SparkException( - s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") - } - } - } else { - iter.map(unpickle.loads(_)) - } - unpickled.map { - case obj if isPair(obj) => - // we only accept (K, V) - val arr = obj.asInstanceOf[Array[_]] - (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) - case other => throw new SparkException( - s"RDD element of type ${other.getClass.getName} cannot be used") - } + + val rdd = pythonToJava(pyRDD, batched).rdd + rdd.first match { + case obj if isPair(obj) => + // we only accept (K, V) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") + } + rdd.map { obj => + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) } } - } - diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index e9ca9166eb4d6..c0cbd28a845be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator { // Create test data for arbitrary custom writable TestWritable val testClass = Seq( - ("1", TestWritable("test1", 123, 54.0)), - ("2", TestWritable("test2", 456, 8762.3)), - ("1", TestWritable("test3", 123, 423.1)), - ("3", TestWritable("test56", 456, 423.5)), - ("2", TestWritable("test2", 123, 5435.2)) + ("1", TestWritable("test1", 1, 1.0)), + ("2", TestWritable("test2", 2, 2.3)), + ("3", TestWritable("test3", 3, 3.1)), + ("5", TestWritable("test56", 5, 5.5)), + ("4", TestWritable("test4", 4, 4.2)) ) val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) } rdd.saveAsNewAPIHadoopFile(classPath, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala new file mode 100644 index 0000000000000..88118e2837741 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler + +/** + * Provides a server from which Executors can read shuffle files (rather than reading directly from + * each other), to provide uninterrupted access to the files in the face of executors being turned + * off or killed. + * + * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. + */ +private[worker] +class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) + extends Logging { + + private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) + private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val useSasl: Boolean = securityManager.isAuthenticationEnabled() + + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) + private val blockHandler = new ExternalShuffleBlockHandler() + private val transportContext: TransportContext = { + val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler + new TransportContext(transportConf, handler) + } + + private var server: TransportServer = _ + + /** Starts the external shuffle service if the user has configured us to. */ + def startIfEnabled() { + if (enabled) { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + } + + def stop() { + if (enabled && server != null) { + server.close() + server = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f1f66d0903f1c..ca262de832e25 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -111,6 +111,9 @@ private[spark] class Worker( val drivers = new HashMap[String, DriverRunner] val finishedDrivers = new HashMap[String, DriverRunner] + // The shuffle service is not actually started unless configured. 
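The new StandaloneWorkerShuffleService above is inert unless spark.shuffle.service.enabled is set on the Worker; combined with the dynamic-allocation requirement added earlier in this patch, a working configuration might look like the sketch below. Values are illustrative, and the spark.dynamicAllocation.enabled flag itself does not appear in this section, so it is shown only as an assumption.

```scala
import org.apache.spark.SparkConf

// Illustrative settings only; 7337 is the default shuffle service port used in the patch.
val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")     // start the service in each Worker
  .set("spark.shuffle.service.port", "7337")
  .set("spark.dynamicAllocation.enabled", "true")   // assumption: not shown in this hunk
  .set("spark.dynamicAllocation.minExecutors", "1")
  .set("spark.dynamicAllocation.maxExecutors", "10")
```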
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -154,6 +157,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -419,6 +423,7 @@ private[spark] class Worker( registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) + shuffleService.stop() webUi.stop() metricsSystem.stop() } @@ -441,7 +446,8 @@ private[spark] object Worker extends Logging { cores: Int, memory: Int, masterUrls: Array[String], - workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workDir: String, + workerNumber: Option[Int] = None): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 697154d762d41..3711824a40cfc 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) + SparkEnv.executorActorSystemName, + hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e24a15f015e1c..caf4d76713d49 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.ActorSystem +import akka.actor.{Props, ActorSystem} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -86,12 +86,17 @@ private[spark] class Executor( conf, executorId, slaveHostname, port, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) + _env.blockManager.initialize(conf.getAppId) _env } else { SparkEnv.get } } + // Create an actor for receiving RPCs from the driver + private val executorActor = env.actorSystem.actorOf( + Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -131,6 +136,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() + env.actorSystem.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -155,7 +161,7 @@ private[spark] class Executor( } override def run() { - val startTime = System.currentTimeMillis() + val deserializeStartTime = System.currentTimeMillis() 
Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -200,7 +206,7 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - startTime + m.executorDeserializeTime = taskStart - deserializeStartTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime m.resultSerializationTime = afterSerialization - beforeSerialization @@ -257,7 +263,7 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) + val reason = new ExceptionFailure(t, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala new file mode 100644 index 0000000000000..41925f7e97e84 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import akka.actor.Actor +import org.apache.spark.Logging + +import org.apache.spark.util.{Utils, ActorLogReceive} + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * Actor that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case TriggerThreadDump => + sender ! 
Utils.getThreadDump() + } + +} diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 5164a74bec4e9..36a1e5d475f46 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -115,7 +115,7 @@ private[spark] class FixedLengthBinaryRecordReader if (currentPosition < splitEnd) { // setup a buffer to store the record val buffer = recordValue.getBytes - fileInputStream.read(buffer, 0, recordLength) + fileInputStream.readFully(buffer) // update our current position currentPosition = currentPosition + recordLength // return true diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 1c4327cf13b51..b937ea825f49e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,15 +17,17 @@ package org.apache.spark.network.netty +import scala.collection.JavaConversions._ import scala.concurrent.{Future, Promise} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} +import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -33,18 +35,30 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService { + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
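For context on the ExecutorActor and TriggerThreadDump message introduced above, here is a minimal, hypothetical sketch of how the driver side could request a thread dump from an executor. The helper name fetchExecutorThreadDump is illustrative only; it leans on BlockManagerMaster.getActorSystemHostPortForExecutor and AkkaUtils.makeExecutorRef (both added later in this patch) and is not itself part of the change.

object ThreadDumpSketch {
  import scala.concurrent.Await
  import akka.pattern.ask
  import org.apache.spark.SparkEnv
  import org.apache.spark.executor.TriggerThreadDump
  import org.apache.spark.util.{AkkaUtils, ThreadStackTrace}

  /** Ask the "ExecutorActor" registered in Executor's constructor for a dump of that executor's threads. */
  def fetchExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
    val env = SparkEnv.get
    // The BlockManagerMaster can resolve the host/port of the executor's actor system
    // (helper added further down in this patch).
    env.blockManager.master.getActorSystemHostPortForExecutor(executorId).map { case (host, port) =>
      val executorActor =
        AkkaUtils.makeExecutorRef("ExecutorActor", env.conf, host, port, env.actorSystem)
      val timeout = AkkaUtils.askTimeout(env.conf)
      Await.result(executorActor.ask(TriggerThreadDump)(timeout), timeout)
        .asInstanceOf[Array[ThreadStackTrace]]
    }
  }
}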
- val serializer = new JavaSerializer(conf) + private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) - transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) - clientFactory = transportContext.createClientFactory() + val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { + val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + if (!authEnabled) { + (nettyRpcHandler, None) + } else { + (new SaslRpcHandler(nettyRpcHandler, securityManager), + Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) + } + } + transportContext = new TransportContext(transportConf, rpcHandler) + clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() logInfo("Server created on " + server.getPort) } @@ -57,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { listener: BlockFetchingListener): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { - val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, blockIds.toArray, listener) - .start(OpenBlocks(blockIds.map(BlockId.apply))) + val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { + override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, blockIds.toArray, listener) + .start(OpenBlocks(blockIds.map(BlockId.apply))) + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. 
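As a usage note for the retry path above, a rough sketch of the settings an application might set to tune it. The exact keys read by TransportConf are not spelled out in this excerpt, so the spark.shuffle.io.* names below are assumptions.

import org.apache.spark.SparkConf

// Sketch only; key names and the retry-wait unit are assumed, not confirmed by this excerpt.
val conf = new SparkConf()
  .set("spark.shuffle.io.maxRetries", "3")  // 0 would bypass RetryingBlockFetcher (see the if statement above)
  .set("spark.shuffle.io.retryWait", "5")   // pause between retry rounds, assumed to be in seconds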
+ new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() + } else { + blockFetchStarter.createAndStart(blockIds, listener) + } } catch { case e: Exception => logError("Exception while beginning fetchBlocks", e) diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 4f6f5e235811d..c2d9578be7ebb 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,12 +23,13 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import org.apache.spark._ - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal +import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} + private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8408b75bb4d65..f198aa8564a54 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} import org.apache.spark.util.Utils import scala.util.Try @@ -600,7 +601,7 @@ private[nio] class ConnectionManager( } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) + replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -634,7 +635,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -778,7 +779,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 946fb5616d3ec..a157e36e2286e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -211,20 +211,11 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null val jobConf = getJobConf() - val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new 
SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) @@ -234,6 +225,18 @@ class HadoopRDD[K, V]( if (bytesReadCallback.isDefined) { context.taskMetrics.inputMetrics = Some(inputMetrics) } + + var reader: RecordReader[K, V] = null + val inputFormat = getInputFormat(jobConf) + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener{ context => closeIfNeeded() } + val key: K = reader.createKey() + val value: V = reader.createValue() + var recordsSinceMetricsUpdate = 0 override def getNext() = { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 6d6b86721ca74..351e145f96f9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -107,20 +107,10 @@ class NewHadoopRDD[K, V]( val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. 
Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) @@ -131,6 +121,18 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) var havePair = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index af17b5d5d2571..22449517d100f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1053,7 +1053,7 @@ class DAGScheduler( logInfo("Resubmitted " + task + ", so marking it as still running") stage.pendingTasks += task - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) @@ -1063,7 +1063,7 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure")) + markStageAsFinished(failedStage, Some(failureMessage)) runningStages -= failedStage } @@ -1094,7 +1094,7 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } - case ExceptionFailure(className, description, stackTrace, metrics) => + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 54904bffdf10b..4e3d9de540783 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -215,7 +215,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala 
index d8c0e2f66df01..5289661eb896b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { { val ret = driver.run() @@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Build a Mesos resource protobuf object */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8e2faff90f9b2..c5f3493477bc5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { val ret = driver.run() @@ -278,8 +278,7 @@ private[spark] class MesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Turn a Spark TaskDescription into a Mesos task */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 71c08e9d5a8c3..be184464e0ae9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.util.Utils /** * Failed to fetch a shuffle block. 
The executor catches this exception and propagates it @@ -30,13 +31,22 @@ private[spark] class FetchFailedException( bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, - reduceId: Int) - extends Exception { + reduceId: Int, + message: String, + cause: Throwable = null) + extends Exception(message, cause) { - override def getMessage: String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + def this( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int, + cause: Throwable) { + this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + } - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + Utils.exceptionString(this)) } /** @@ -46,7 +56,4 @@ private[spark] class MetadataFetchFailedException( shuffleId: Int, reduceId: Int, message: String) - extends FetchFailedException(null, shuffleId, -1, reduceId) { - - override def getMessage: String = message -} + extends FetchFailedException(null, shuffleId, -1, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index f49917b7fe833..e3e7434df45b0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.util.{Failure, Success, Try} import org.apache.spark._ import org.apache.spark.serializer.Serializer @@ -52,21 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Some(block) => { + case Success(block) => { block.asInstanceOf[Iterator[T]] } - case None => { + case Failure(e) => { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block") + "Failed to get block " + blockId + ", which is not a shuffle block", e) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5f5dd0dc1c63f..e48d7772d6ee9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -40,7 +40,6 @@ import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -57,6 +56,12 @@ private[spark] class BlockResult( inputMetrics.bytesRead = bytes } +/** + * Manager running on 
every node (driver and executors) which provides interfaces for putting and + * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). + * + * Note that #initialize() must be called before the BlockManager is usable. + */ private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -66,11 +71,10 @@ private[spark] class BlockManager( val conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) + blockTransferService: BlockTransferService, + securityManager: SecurityManager) extends BlockDataManager with Logging { - blockTransferService.init(this) - val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -92,7 +96,12 @@ private[spark] class BlockManager( private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) - private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337) + + // Port used by the external shuffle service. In Yarn mode, this may be already be + // set through the Hadoop configuration as the server is launched in the Yarn NM. + private val externalShuffleServicePort = + Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + // Check that we're not using external shuffle service with consolidated shuffle files. if (externalShuffleServiceEnabled && conf.getBoolean("spark.shuffle.consolidateFiles", false) @@ -102,22 +111,17 @@ private[spark] class BlockManager( + " switch to sort-based shuffle.") } - val blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external // service, or just our own Executor's BlockManager. - private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) { - BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) - } else { - blockManagerId - } + private[spark] var shuffleServerId: BlockManagerId = _ // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val appId = conf.get("spark.app.id", "unknown-app-id") - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId) + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager, + securityManager.isAuthenticationEnabled()) } else { blockTransferService } @@ -150,8 +154,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - initialize() - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. 
The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -170,16 +172,34 @@ private[spark] class BlockManager( conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) = { + blockTransferService: BlockTransferService, + securityManager: SecurityManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) } /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured. + * Initializes the BlockManager with the given appId. This is not performed in the constructor as + * the appId may not be known at BlockManager instantiation time (in particular for the driver, + * where it is only learned after registration with the TaskScheduler). + * + * This method initializes the BlockTransferService and ShuffleClient, registers with the + * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * service if configured. */ - private def initialize(): Unit = { + def initialize(appId: String): Unit = { + blockTransferService.init(this) + shuffleClient.init(appId) + + blockManagerId = BlockManagerId( + executorId, blockTransferService.hostName, blockTransferService.port) + + shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) // Register Executors' configuration with the local shuffle service, if one should exist. @@ -206,7 +226,6 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - val attemptsRemaining = logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d08e1419e3e41..b63c7f191155c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -88,6 +88,10 @@ class BlockManagerMaster( askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + } + /** * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 5e375a2553979..685b2e11440fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetPeers(blockManagerId) => sender ! 
getPeers(blockManagerId) + case GetActorSystemHostPortForExecutor(executorId) => + sender ! getActorSystemHostPortForExecutor(executorId) + case GetMemoryStatus => sender ! memoryStatus @@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus Seq.empty } } + + /** + * Returns the hostname and port of an executor's actor system, based on the Akka address of its + * BlockManagerSlaveActor. + */ + private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId); + host <- info.slaveActor.path.address.host; + port <- info.slaveActor.path.address.port + ) yield { + (host, port) + } + } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 291ddfcc113ac..3f32099d08cc9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ee89c7e521f4e..6b1f57a069431 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.util.{Failure, Success, Try} import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.BlockTransferService @@ -55,7 +56,7 @@ final class ShuffleBlockFetcherIterator( blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { import ShuffleBlockFetcherIterator._ @@ -91,7 +92,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. 
*/ - private[this] var currentResult: FetchResult = null + @volatile private[this] var currentResult: FetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that @@ -118,16 +119,18 @@ final class ShuffleBlockFetcherIterator( private[this] def cleanup() { isZombie = true // Release the current buffer if necessary - if (currentResult != null && !currentResult.failed) { - currentResult.buf.release() + currentResult match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => } // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { val result = iter.next() - if (!result.failed) { - result.buf.release() + result match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => } } } @@ -151,7 +154,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) shuffleMetrics.remoteBytesRead += buf.size shuffleMetrics.remoteBlocksFetched += 1 } @@ -160,7 +163,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FetchResult(BlockId(blockId), -1, null)) + results.put(new FailureFetchResult(BlockId(blockId), e)) } } ) @@ -231,12 +234,12 @@ final class ShuffleBlockFetcherIterator( val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 buf.retain() - results.put(new FetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(blockId, -1, null)) + results.put(new FailureFetchResult(blockId, e)) return } } @@ -267,15 +270,17 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Option[Iterator[Any]]) = { + override def next(): (BlockId, Try[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (!result.failed) { - bytesInFlight -= result.size + + result match { + case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case _ => } // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && @@ -283,20 +288,21 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { - None - } else { - val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream()) - val iter = serializer.newInstance().deserializeStream(is).asIterator - Some(CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. 
- currentResult = null - result.buf.release() - })) + val iteratorTry: Try[Iterator[Any]] = result match { + case FailureFetchResult(_, e) => Failure(e) + case SuccessFetchResult(blockId, _, buf) => { + val is = blockManager.wrapForCompression(blockId, buf.createInputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Success(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + buf.release() + })) + } } - (result.blockId, iteratorOpt) + (result.blockId, iteratorTry) } } @@ -315,14 +321,30 @@ object ShuffleBlockFetcherIterator { } /** - * Result of a fetch from a remote block. A failure is represented as size == -1. + * Result of a fetch from a remote block. + */ + private[storage] sealed trait FetchResult { + val blockId: BlockId + } + + /** + * Result of a fetch from a remote block successfully. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. -1 if failure is present. - * @param buf [[ManagedBuffer]] for the content. null is error. + * Note that this is NOT the exact bytes. + * @param buf [[ManagedBuffer]] for the content. */ - case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) { - def failed: Boolean = size == -1 - if (failed) assert(buf == null) else assert(buf != null) + private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + extends FetchResult { + require(buf != null) + require(size >= 0) } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId block id + * @param e the failure exception + */ + private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index f02904df31fcf..51dc08f668a43 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,6 +24,9 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" + val TASK_DESERIALIZATION_TIME = + """Time spent deserializating the task closure on the executor.""" + val INPUT = "Bytes read from Hadoop or from Spark storage." val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala new file mode 100644 index 0000000000000..e9c755e36f716 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.util.Try +import scala.xml.{Text, Node} + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { + + private val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val executorId = Option(request.getParameter("executorId")).getOrElse { + return Text(s"Missing executorId parameter") + } + val time = System.currentTimeMillis() + val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) + + val content = maybeThreadDump.map { threadDump => + val dumpRows = threadDump.map { thread => +
        <div class="accordion-group">
+          <div class="accordion-heading" onclick="$(this).next().toggleClass('hidden')">
+            <a class="accordion-toggle">
+              Thread {thread.threadId}: {thread.threadName} ({thread.threadState})
+            </a>
+          </div>
+          <div class="accordion-body hidden">
+            <div class="accordion-inner">
+              <pre>{thread.stackTrace}</pre>
+            </div>
+          </div>
+        </div>
+      }
+
+      <div class="row-fluid">
+        <p>Updated at {UIUtils.formatDate(time)}</p>
+        {
+          // scalastyle:off
+          <p><a class="expandbutton"
+                onClick="$('.accordion-body').removeClass('hidden'); $('.expandbutton').toggleClass('hidden')">
+            Expand All
+          </a></p>
+          <p><a class="expandbutton hidden"
+                onClick="$('.accordion-body').addClass('hidden'); $('.expandbutton').toggleClass('hidden')">
+            Collapse All
+          </a></p>
+          // scalastyle:on
+        }
+        <div class="accordion">{dumpRows}</div>
+      </div>
+ }.getOrElse(Text("Error fetching thread dump")) + UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b0e3bb3b552fd..048fee3ce1ff4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -41,7 +41,10 @@ private case class ExecutorSummaryInfo( totalShuffleWrite: Long, maxMemory: Long) -private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage( + parent: ExecutorsTab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -75,6 +78,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { Shuffle Write + {if (threadDumpEnabled) Thread Dump else Seq.empty} {execInfoSorted.map(execRow)} @@ -133,6 +137,15 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { {Utils.bytesToString(info.totalShuffleWrite)} + { + if (threadDumpEnabled) { + + Thread Dump + + } else { + Seq.empty + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 9e0e71a51a408..ba97630f025c1 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = parent.executorsListener + val sc = parent.sc + val threadDumpEnabled = + sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - attachPage(new ExecutorsPage(this)) + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this)) + } } /** diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index b5207360510dd..e3223403c17f4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] + + // Number of completed and failed stages, may not actually equal to completedStages.size and + // failedStages.size respectively due to completedStage and failedStages only maintain the latest + // part of the stages, the earlier ones will be removed when there are too many stages for + // memory sake. + var numCompletedStages = 0 + var numFailedStages = 0 // Map from pool name to a hash map (map from stage id to StageInfo). 
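A small usage sketch for the spark.ui.threadDumpsEnabled flag read in ExecutorsTab above; the flag defaults to true, and setting it to false simply opts out of the new pages.

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("example")
  // Hide the per-executor "Thread Dump" pages and the extra Executors-page column added in this patch.
  .set("spark.ui.threadDumpsEnabled", "false")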
val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() @@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage + numCompletedStages += 1 trimIfNecessary(completedStages) } else { failedStages += stage + numFailedStages += 1 trimIfNecessary(failedStages) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 6e718eecdd52a..83a7898071c9b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") listener.synchronized { val activeStages = listener.activeStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq + val numCompletedStages = listener.numCompletedStages val failedStages = listener.failedStages.reverse.toSeq + val numFailedStages = listener.numFailedStages val now = System.currentTimeMillis val activeStagesTable = @@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
              <li>
                <a href="#completed"><strong>Completed Stages:</strong></a>
-               {completedStages.size}
+               {numCompletedStages}
              </li>
              <li>
                <a href="#failed"><strong>Failed Stages:</strong></a>
-               {failedStages.size}
+               {numFailedStages}
              </li>
@@ -86,9 +88,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
      }} ++
      <h4 id="active">Active Stages ({activeStages.size})</h4> ++
      activeStagesTable.toNodeSeq ++
-     <h4 id="completed">Completed Stages ({completedStages.size})</h4> ++
+     <h4 id="completed">Completed Stages ({numCompletedStages})</h4> ++
      completedStagesTable.toNodeSeq ++
-     <h4 id="failed">Failed Stages ({failedStages.size})</h4> ++
+     <h4 id="failed">Failed Stages ({numFailedStages})</h4> ++
    ++ failedStagesTable.toNodeSeq UIUtils.headerSparkPage("Spark Stages", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7cc03b7d333df..250bddbe2f262 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.executor.TaskMetrics import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ @@ -112,6 +114,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Scheduler Delay +
            <li>
+              <span data-toggle="tooltip"
+                    title={ToolTips.TASK_DESERIALIZATION_TIME} data-placement="right">
+                <input type="checkbox" name={TaskDetailsClassNames.TASK_DESERIALIZATION_TIME}/>
+                <span class="additional-metric-title">Task Deserialization Time</span>
+              </span>
+            </li>
  • @@ -147,6 +156,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", TaskDetailsClassNames.GC_TIME), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ @@ -179,6 +189,17 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } } + val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + + + Task Deserialization Time + + +: getFormattedTimeQuantiles(deserializationTimes) + val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } @@ -266,6 +287,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val listings: Seq[Seq[Node]] = Seq( {serviceQuantiles}, {schedulerDelayQuantiles}, + + {deserializationQuantiles} + {gcQuantiles}, {serializationQuantiles} @@ -314,6 +338,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = info.gettingResultTime @@ -367,6 +392,10 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { class={TaskDetailsClassNames.SCHEDULER_DELAY}> {UIUtils.formatDuration(schedulerDelay.toLong)} + + {UIUtils.formatDuration(taskDeserializationTime.toLong)} + {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} @@ -409,13 +438,37 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {diskBytesSpilledReadable} }} - - {errorMessage.map { e =>
    {e}
    }.getOrElse("")} - + {errorMessageCell(errorMessage)} } } + private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { + val error = errorMessage.getOrElse("") + val isMultiline = error.indexOf('\n') >= 0 + // Display the first line by default + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + error.substring(0, error.indexOf('\n')) + } else { + error + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + {errorSummary}{details} + } + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { val totalExecutionTime = { if (info.gettingResultTime > 0) { @@ -424,6 +477,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { (info.finishTime - info.launchTime) } } - totalExecutionTime - metrics.executorRunTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + totalExecutionTime - metrics.executorRunTime - executorOverhead } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 4ee7f08ab47a2..3b4866e05956d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,6 +22,8 @@ import scala.xml.Text import java.util.Date +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -195,7 +197,29 @@ private[ui] class FailedStageTable( override protected def stageRow(s: StageInfo): Seq[Node] = { val basicColumns = super.stageRow(s) - val failureReason =
    {s.failureReason.getOrElse("")}
    - basicColumns ++ failureReason + val failureReason = s.failureReason.getOrElse("") + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + val failureReasonHtml = {failureReasonSummary}{details} + basicColumns ++ failureReasonHtml } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 23d672cabda07..eb371bd0ea7ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -24,6 +24,7 @@ package org.apache.spark.ui.jobs private object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val GC_TIME = "gc_time" + val TASK_DESERIALIZATION_TIME = "deserialization_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 79e398eb8c104..10010bdfa1a51 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -212,4 +212,18 @@ private[spark] object AkkaUtils extends Logging { logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def makeExecutorRef( + name: String, + conf: SparkConf, + host: String, + port: Int, + actorSystem: ActorSystem): ActorRef = { + val executorActorSystemName = SparkEnv.executorActorSystemName + Utils.checkHost(host, "Expected hostname") + val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index b6a099825f01b..390310243ee0a 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -25,10 +25,13 @@ private[spark] // scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { // scalastyle:on + + private[this] var completed = false def next() = sub.next() def hasNext = { val r = sub.hasNext - if (!r) { + if (!r && !completed) { + completed = true completion() } r diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 43c7fba06694a..f15d0c856663f 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -279,13 +279,15 @@ private[spark] object JsonProtocol { ("Block Manager Address" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ - ("Reduce ID" -> fetchFailed.reduceId) + ("Reduce ID" -> fetchFailed.reduceId) ~ + ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace 
= stackTraceToJson(exceptionFailure.stackTrace) val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ + ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) case ExecutorLostFailure(executorId) => ("Executor ID" -> executorId) @@ -629,13 +631,17 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId) + val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, + message.getOrElse("Unknown reason")) case `exceptionFailure` => val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") + val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). + map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - new ExceptionFailure(className, description, stackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala new file mode 100644 index 0000000000000..d4e0ad93b966a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Used for shipping per-thread stacktraces from the executors to driver. 
+ */ +private[spark] case class ThreadStackTrace( + threadId: Long, + threadName: String, + threadState: Thread.State, + stackTrace: String) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b402c5f334bb0..a14d6125484fe 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.util.jar.Attributes.Name @@ -44,6 +45,7 @@ import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -753,6 +755,7 @@ private[spark] object Utils extends Logging { /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. */ def deleteRecursively(file: File) { if (file != null) { @@ -1596,19 +1599,31 @@ private[spark] object Utils extends Logging { .orNull } - /** Return a nice string representation of the exception, including the stack trace. */ - def exceptionString(e: Exception): String = { - if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) + /** + * Return a nice string representation of the exception. It will call "printStackTrace" to + * recursively generate the stack trace including the exception and its causes. + */ + def exceptionString(e: Throwable): String = { + if (e == null) { + "" + } else { + // Use e.printStackTrace here because e.getStackTrace doesn't include the cause + val stringWriter = new StringWriter() + e.printStackTrace(new PrintWriter(stringWriter)) + stringWriter.toString + } } - /** Return a nice string representation of the exception, including the stack trace. */ - def exceptionString( - className: String, - description: String, - stackTrace: Array[StackTraceElement]): String = { - val desc = if (description == null) "" else description - val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") - s"$className: $desc\n$st" + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ + def getThreadDump(): Array[ThreadStackTrace] = { + // We need to filter out null values here because dumpAllThreads() may return null array + // elements for threads that are dead / don't exist. + val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) + threadInfos.sortBy(_.getThreadId).map { case threadInfo => + val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, + threadInfo.getThreadState, stackTrace) + } } /** @@ -1767,6 +1782,21 @@ private[spark] object Utils extends Logging { val manifest = new JarManifest(manifestUrl.openStream()) manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) }.getOrElse("Unknown") + + /** + * Return the value of a config either through the SparkConf or the Hadoop configuration + * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf + * if the key is not set in the Hadoop configuration. 
+ */ + def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { + val sparkValue = conf.get(key, default) + if (SparkHadoopUtil.get.isYarnMode) { + SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue) + } else { + sparkValue + } + } + } /** diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 792b9cd8b6ff2..6608ed1e57b38 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -63,8 +63,9 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { rdd.count() rdd.count() - // Invalidate the registered executors, disallowing access to their shuffle blocks. - rpcHandler.clearRegisteredExecutors() + // Invalidate the registered executors, disallowing access to their shuffle blocks (without + // deleting the actual shuffle files, so we could access them without the shuffle service). + rpcHandler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */) // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry" // being set. diff --git a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala index 33bd1afea2470..48c386ba04311 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer import java.io.{FileWriter, PrintWriter, File} class InputMetricsSuite extends FunSuite with SharedSparkContext { - test("input metrics when reading text file") { + test("input metrics when reading text file with single split") { val file = new File(getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(file)) pw.println("some stuff") @@ -48,6 +48,29 @@ class InputMetricsSuite extends FunSuite with SharedSparkContext { // Wait for task end events to come in sc.listenerBus.waitUntilEmpty(500) assert(taskBytesRead.length == 2) - assert(taskBytesRead.sum == file.length()) + assert(taskBytesRead.sum >= file.length()) + } + + test("input metrics when reading text file with multiple splits") { + val file = new File(getClass.getSimpleName + ".txt") + val pw = new PrintWriter(new FileWriter(file)) + for (i <- 0 until 10000) { + pw.println("some stuff") + } + pw.close() + file.deleteOnExit() + + val taskBytesRead = new ArrayBuffer[Long]() + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + } + }) + sc.textFile("file://" + file.getAbsolutePath, 2).count() + + // Wait for task end events to come in + sc.listenerBus.waitUntilEmpty(500) + assert(taskBytesRead.length == 2) + assert(taskBytesRead.sum >= file.length()) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala new file mode 100644 index 0000000000000..9162ec9801663 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio._ +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.util.{Failure, Success, Try} + +import org.apache.commons.io.IOUtils +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.{SecurityManager, SparkConf} +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} + +class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { + test("security default off") { + testConnection(new SparkConf, new SparkConf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on same password") { + val conf = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on mismatch password") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate.secret", "bad") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Mismatched response") + } + } + + test("security mismatch auth off on server") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "false") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => // any funny error may occur, sever will interpret SASL token as RPC + } + } + + test("security mismatch auth off on client") { + val conf0 = new SparkConf() + .set("spark.authenticate", "false") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "true") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Expected SaslMessage") + } + } + + /** + * Creates two servers with different configurations and sees if they can talk. + * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed + * properly. 
We will throw an out-of-band exception if something other than that goes wrong. + */ + private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = { + val blockManager = mock[BlockDataManager] + val blockId = ShuffleBlockId(0, 1, 2) + val blockString = "Hello, world!" + val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes)) + when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) + + val securityManager0 = new SecurityManager(conf0) + val exec0 = new NettyBlockTransferService(conf0, securityManager0) + exec0.init(blockManager) + + val securityManager1 = new SecurityManager(conf1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1) + exec1.init(blockManager) + + val result = fetchBlock(exec0, exec1, "1", blockId) match { + case Success(buf) => + IOUtils.toString(buf.createInputStream()) should equal(blockString) + buf.release() + Success() + case Failure(t) => + Failure(t) + } + exec0.close() + exec1.close() + result + } + + /** Synchronously fetches a single block, acting as the given executor fetching from another. */ + private def fetchBlock( + self: BlockTransferService, + from: BlockTransferService, + execId: String, + blockId: BlockId): Try[ManagedBuffer] = { + + val promise = Promise[ManagedBuffer]() + + self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + promise.failure(exception) + } + + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + promise.success(data.retain()) + } + }) + + Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + promise.future.value.get + } +} + diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index b70734dfe37cf..716f875d30b8a 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") + conf.set("spark.app.id", "app-id") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) var numReceivedMessages = 0 @@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite { test("security mismatch password") { val conf = new SparkConf conf.set("spark.authenticate", "true") + conf.set("spark.app.id", "app-id") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) @@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite { None }) - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") + val badconf = conf.clone.set("spark.authenticate.secret", "bad") val badsecurityManager = new SecurityManager(badconf) val managerServer = new ConnectionManager(0, badconf, badsecurityManager) var numReceivedServerMessages = 0 diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a2e4f712db55b..819f95634bcdc 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -431,7 +431,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -461,7 +461,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(CompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null, Map[Long, Any](), null, @@ -472,7 +472,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(CompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), null, Map[Long, Any](), null, @@ -624,7 +624,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -655,7 +655,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F (Success, makeMapStatus("hostB", 1)))) // pretend stage 0 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
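A note on the JsonProtocol changes above: the backward-compatible parsing (defaulting a missing "Message" to "Unknown reason", and a missing "Full Stack Trace" to null) rests on treating an absent JSON field as an Option instead of failing the extraction. Below is a minimal, self-contained json4s sketch of that pattern; the object name and the standalone jsonOption helper are illustrative stand-ins for Spark's private helper of the same name in Utils, and the sample JSON is made up.

import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object OptionalFieldSketch {
  // Treat JNothing (field absent) as None instead of throwing during extraction.
  def jsonOption(json: JValue): Option[JValue] = json match {
    case JNothing => None
    case value => Some(value)
  }

  def main(args: Array[String]): Unit = {
    implicit val formats = DefaultFormats
    // An "old" event with no "Message" field, as written by earlier Spark versions.
    val oldEvent = parse("""{"Shuffle ID": 17, "Map ID": 18, "Reduce ID": 19}""")
    val message = jsonOption(oldEvent \ "Message").map(_.extract[String])
    println(message.getOrElse("Unknown reason")) // falls back to the default
  }
}

This is also why the compatibility tests later in this patch can simply strip a field from a serialized event and still expect the documented default value.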
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c6d7105592096..f63e772bf1e59 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -62,7 +62,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + mapOutputTracker, shuffleManager, transfer, securityMgr) + store.initialize("app-id") allStores += store store } @@ -262,7 +263,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, - 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr) + failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 715b740b857b2..9529502bc8e10 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr) + manager.initialize("app-id") + manager } before { @@ -793,7 +795,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer) + new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr) // The put should fail since a1 is not serializable. 
class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 28f766570e96f..1eaabb93adbed 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -102,7 +102,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") val (blockId, subIterator) = iterator.next() - assert(subIterator.isDefined, + assert(subIterator.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") // Make sure we release the buffer once the iterator is exhausted. @@ -230,8 +230,8 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { sem.acquire() // The first block should be defined, and the last two are not defined (due to failure) - assert(iterator.next()._2.isDefined === true) - assert(iterator.next()._2.isDefined === false) - assert(iterator.next()._2.isDefined === false) + assert(iterator.next()._2.isSuccess) + assert(iterator.next()._2.isFailure) + assert(iterator.next()._2.isFailure) } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 6567c5ab836e7..2608ad4b32e1e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -115,8 +115,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // Go through all the failure cases to make sure we are counting them as failures. val taskFailedReasons = Seq( Resubmitted, - new FetchFailed(null, 0, 0, 0), - new ExceptionFailure("Exception", "description", null, None), + new FetchFailed(null, 0, 0, 0, "ignored"), + ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala new file mode 100644 index 0000000000000..3755d43e25ea8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import org.scalatest.FunSuite + +class CompletionIteratorSuite extends FunSuite { + test("basic test") { + var numTimesCompleted = 0 + val iter = List(1, 2, 3).iterator + val completionIter = CompletionIterator[Int, Iterator[Int]](iter, { numTimesCompleted += 1 }) + + assert(completionIter.hasNext) + assert(completionIter.next() === 1) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 2) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 3) + assert(numTimesCompleted === 0) + + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + + // SPARK-4264: Calling hasNext should not trigger the completion callback again. + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index d235d7a0ed839..39e69851e7e3c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -107,8 +107,9 @@ class JsonProtocolSuite extends FunSuite { testJobResult(jobFailed) // TaskEndReason - val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19) - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None) + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "Some exception") + val exceptionFailure = new ExceptionFailure(exception, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) @@ -126,6 +127,13 @@ class JsonProtocolSuite extends FunSuite { testBlockId(StreamBlockId(1, 2L)) } + test("ExceptionFailure backward compatibility") { + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) + .removeField({ _._1 == "Full Stack Trace" }) + assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("StageInfo backward compatibility") { val info = makeStageInfo(1, 2, 3, 4L, 5L) val newJson = JsonProtocol.stageInfoToJson(info) @@ -176,6 +184,17 @@ class JsonProtocolSuite extends FunSuite { deserializedBmRemoved) } + test("FetchFailed backwards compatibility") { + // FetchFailed in Spark 1.1.0 does not have an "Message" property. + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "ignored") + val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) + .removeField({ _._1 == "Message" }) + val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "Unknown reason") + assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") @@ -184,6 +203,15 @@ class JsonProtocolSuite extends FunSuite { assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) } + test("ExecutorLostFailure backward compatibility") { + // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. 
+ val executorLostFailure = ExecutorLostFailure("100") + val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) + .removeField({ _._1 == "Executor ID" }) + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -396,10 +424,12 @@ class JsonProtocolSuite extends FunSuite { assert(r1.mapId === r2.mapId) assert(r1.reduceId === r2.reduceId) assertEquals(r1.bmAddress, r2.bmAddress) + assert(r1.message === r2.message) case (r1: ExceptionFailure, r2: ExceptionFailure) => assert(r1.className === r2.className) assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) + assert(r1.fullStackTrace === r2.fullStackTrace) assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => diff --git a/dev/run-tests b/dev/run-tests index 0e9eefa76a18b..de607e4344453 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string #+ will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi diff --git a/docs/building-spark.md b/docs/building-spark.md index 4cc0b1f2e5116..238ddae15545e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -99,14 +99,11 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} - - # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` profile to your existing build options. By default Spark will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using -the `-Phive-0.12.0` profile. NOTE: currently the JDBC server is only -supported for Hive 0.12.0. +the `-Phive-0.12.0` profile. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package @@ -121,8 +118,8 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-0.12.0 clean package - mvn -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package + mvn -Pyarn -Phadoop-2.3 -Phive test The ScalaTest plugin also supports running only a specific test suite as follows: @@ -185,16 +182,16 @@ can be set to control the SBT build. For example: Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. 
The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive test To run only a specific test suite as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 "test-only org.apache.spark.repl.ReplSuite" + sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" To run test suites of a specific sub project as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 core/test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test # Speeding up Compilation with Zinc diff --git a/docs/configuration.md b/docs/configuration.md index 685101ea5c9c9..0f9eb81f6e993 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,16 +21,22 @@ application. These properties can be set directly on a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your `SparkContext`. `SparkConf` allows you to configure some of the common properties (e.g. master URL and application name), as well as arbitrary key-value pairs through the -`set()` method. For example, we could initialize an application as follows: +`set()` method. For example, we could initialize an application with two threads as follows: + +Note that we run with local[2], meaning two threads - which represents "minimal" parallelism, +which can help detect bugs that only exist when we run in a distributed context. {% highlight scala %} val conf = new SparkConf() - .setMaster("local") + .setMaster("local[2]") .setAppName("CountingSheep") .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} +Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually +require one to prevent any sort of starvation issues. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 7f9d4c6563944..d5b044d94fdd7 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -88,11 +88,11 @@ JavaPairRDD predictionAndLabel = return new Tuple2(model.predict(p.features()), p.label()); } }); -double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() { +double accuracy = predictionAndLabel.filter(new Function, Boolean>() { @Override public Boolean call(Tuple2 pl) { - return pl._1() == pl._2(); + return pl._1().equals(pl._2()); } - }).count() / test.count(); + }).count() / (double) test.count(); {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 10a5131c07414..ca8c29218f52d 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) { {% endhighlight %} +
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +hypothesis tests. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors, Matrices +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics + +sc = SparkContext() + +vec = Vectors.dense(...) # a vector composed of the frequencies of events + +# compute the goodness of fit. If a second vector to test against is not supplied as a parameter, +# the test runs against a uniform distribution. +goodnessOfFitTestResult = Statistics.chiSqTest(vec) +print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. + +mat = Matrices.dense(...) # a contingency matrix + +# conduct Pearson's independence test on the input contingency matrix +independenceTestResult = Statistics.chiSqTest(mat) +print independenceTestResult # summary of the test including the p-value, degrees of freedom... + +obs = sc.parallelize(...) # an RDD of LabeledPoint(feature, label) + +# The contingency table is constructed from an RDD of LabeledPoint and used to conduct +# the independence test. Returns an array containing the ChiSqTestResult for every feature +# against the label. +featureTestResults = Statistics.chiSqTest(obs) + +for i, result in enumerate(featureTestResults): + print "Column %d:" % (i + 1) + print result +{% endhighlight %} +
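For comparison, here is a rough Scala sketch of the same two local calls (goodness of fit on a frequency vector, independence on a contingency matrix); the concrete numbers are made up for illustration and are not part of this patch:

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.Statistics

// Observed event frequencies (made-up values, for illustration only).
val vec = Vectors.dense(0.1, 0.15, 0.2, 0.3, 0.25)
// With no expected vector supplied, the goodness-of-fit test compares against a uniform distribution.
val goodnessOfFitTestResult = Statistics.chiSqTest(vec)
println(goodnessOfFitTestResult)

// A 3x2 contingency matrix in column-major order, again with made-up counts.
val mat = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
val independenceTestResult = Statistics.chiSqTest(mat)
println(independenceTestResult)
{% endhighlight %}

As in the Python example, printing the returned ChiSqTestResult shows the p-value, degrees of freedom, test statistic, method used, and null hypothesis.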
    + ## Random data generation diff --git a/docs/security.md b/docs/security.md index ec0523184d665..1e206a139fb72 100644 --- a/docs/security.md +++ b/docs/security.md @@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* ## Web UI diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d4ade939c3a6e..e399fecbbc78c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or spark.sql.parquet.cacheMetadata - false + true Turns on caching of Parquet schema metadata. Can speed up querying of static data. spark.sql.parquet.compression.codec - snappy + gzip Sets the compression codec use when writing Parquet files. Acceptable values include: uncompressed, snappy, gzip, lzo. + + spark.sql.hive.convertMetastoreParquet + true + + When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in + support. + + ## JSON Datasets @@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL Property NameDefaultMeaning spark.sql.inMemoryColumnarStorage.compressed - false + true When set to true Spark SQL will automatically select a compression codec for each column based on statistics of the data. @@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL spark.sql.inMemoryColumnarStorage.batchSize - 1000 + 10000 Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar Property NameDefaultMeaning spark.sql.autoBroadcastJoinThreshold - 10000 + 10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 8bbba88b31978..44a1f3ad7560b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -68,7 +68,9 @@ import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// Create a local StreamingContext with two working thread and batch interval of 1 second +// Create a local StreamingContext with two working thread and batch interval of 1 second. +// The master requires 2 cores to prevent from a starvation scenario. 
+ val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} @@ -586,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver]( A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are: -##### Points to remember: +##### Points to remember {:.no_toc} -- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. -- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data. - +- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. +- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay). So, a "local" master URL in a streaming app is generally going to cause starvation for the processor. +Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally. +See [Spark Properties] (configuration.html#spark-properties.html). + ### Basic Sources {:.no_toc} diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 31f9771223e51..4aa908242eeaa 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -18,5 +18,9 @@ # limitations under the License. # -cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" +# Preserve the user's CWD so that relative paths are passed correctly to +#+ the underlying Python script. +SPARK_EC2_DIR="$(dirname $0)" + +PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \ + python "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0d6b82b4944f3..a5396c2375915 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -40,9 +40,11 @@ from boto import ec2 DEFAULT_SPARK_VERSION = "1.1.0" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) +MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" +AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) class UsageError(Exception): @@ -583,10 +585,23 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") + ssh( + host=master, + opts=opts, + command="rm -rf spark-ec2" + + " && " + + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + ) print "Deploying files to master..." 
- deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) + deploy_files( + conn=conn, + root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", + opts=opts, + master_nodes=master_nodes, + slave_nodes=slave_nodes, + modules=modules + ) print "Running setup on master..." setup_spark_cluster(master, opts) @@ -723,6 +738,8 @@ def get_num_disks(instance_type): # cluster (e.g. lists of masters and slaves). Files are only deployed to # the first master instance in the cluster, and we expect the setup # script to be run on that instance to copy them to other nodes. +# +# root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): active_master = master_nodes[0].public_dns_name diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java new file mode 100644 index 0000000000000..1af2067b2b929 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.util.MLUtils; + +/** + * Classification and regression using gradient-boosted decision trees. + */ +public final class JavaGradientBoostedTrees { + + private static void usage() { + System.err.println("Usage: JavaGradientBoostedTrees " + + " "); + System.exit(-1); + } + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + String algo = "Classification"; + if (args.length >= 1) { + datapath = args[0]; + } + if (args.length >= 2) { + algo = args[1]; + } + if (args.length > 2) { + usage(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Set parameters. + // Note: All features are treated as continuous. 
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); + boostingStrategy.setNumIterations(10); + boostingStrategy.weakLearnerParams().setMaxDepth(5); + + if (algo.equals("Classification")) { + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + } else if (algo.equals("Regression")) { + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainMSE = + predictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + model); + } else { + usage(); + } + + sc.stop(); + } +} diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py new file mode 100644 index 0000000000000..540dae785f6ea --- /dev/null +++ b/examples/src/main/python/mllib/dataset_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use SchemaRDD as a dataset for ML. 
Run with:: + bin/spark-submit examples/src/main/python/mllib/dataset_example.py +""" + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.util import MLUtils +from pyspark.mllib.stat import Statistics + + +def summarize(dataset): + print "schema: %s" % dataset.schema().json() + labels = dataset.map(lambda r: r.label) + print "label average: %f" % labels.mean() + features = dataset.map(lambda r: r.features) + summary = Statistics.colStats(features) + print "features average: %r" % summary.mean() + +if __name__ == "__main__": + if len(sys.argv) > 2: + print >> sys.stderr, "Usage: dataset_example.py " + exit(-1) + sc = SparkContext(appName="DatasetExample") + sqlCtx = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + points = MLUtils.loadLibSVMFile(sc, input) + dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + summarize(dataset0) + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print "Save dataset as a Parquet file to %s." % tempdir + dataset0.saveAsParquetFile(tempdir) + print "Load it back and summarize it again." + dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + summarize(dataset1) + shutil.rmtree(tempdir) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 0000000000000..f8d83f4ec7327 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DatasetExample") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. + val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. + val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 49751a30491d0..63f02cf7b98b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -154,20 +154,30 @@ object DecisionTreeRunner { } } - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DecisionTreeRunner with 
$params") - val sc = new SparkContext(conf) - - println(s"DecisionTreeRunner with parameters:\n$params") - + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset, number of classes), + * where the number of classes is inferred from data (and set to 0 for Regression) + */ + private[mllib] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: Algo, + fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = { // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + val origExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache() } // For classification, re-index classes if needed. - val (examples, classIndexMap, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -205,14 +215,14 @@ object DecisionTreeRunner { } // Create training, test sets. - val splits = if (params.testInput != "") { + val splits = if (testInput != "") { // Load testInput. val numFeatures = examples.take(1)(0).features.size - val origTestExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) + val origTestExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } - params.algo match { + algo match { case Classification => { // classCounts: class --> # examples in class val testExamples = { @@ -229,17 +239,31 @@ object DecisionTreeRunner { } } else { // Split input into training, test. - examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + examples.randomSplit(Array(1.0 - fracTest, fracTest)) } val training = splits(0).cache() val test = splits(1).cache() + val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") examples.unpersist(blocking = false) + (training, test, numClasses) + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") + val sc = new SparkContext(conf) + + println(s"DecisionTreeRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat, + params.testInput, params.algo, params.fracTest) + val impurityCalculator = params.impurity match { case Gini => impurity.Gini case Entropy => impurity.Entropy @@ -338,7 +362,9 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. 
*/ - private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + private[mllib] def meanSquaredError( + tree: WeightedEnsembleModel, + data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala new file mode 100644 index 0000000000000..9b6db01448be0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. + */ +object GradientBoostedTrees { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTrees with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $testAccuracy") + } else if (params.algo == "Regression") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. 
+ } + val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = DecisionTreeRunner.meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 8796c28db8a66..91a0a860d6c71 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -106,9 +106,11 @@ object MovieLensALS { Logger.getRootLogger.setLevel(Level.WARN) + val implicitPrefs = params.implicitPrefs + val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - if (params.implicitPrefs) { + if (implicitPrefs) { /* * MovieLens ratings are on a scale of 1-5: * 5: Must see diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 4520beb991515..2b6137be25547 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -45,8 +45,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { - index.update(srcIds(0), 0) - var currSrcId: VertexId = srcIds(0) + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { srcIds(i) = edgeArray(i).srcId diff --git a/mllib/pom.xml b/mllib/pom.xml index fb7239e779aae..87a7ddaba97f2 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -45,6 +45,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index acdc67ddc660a..d832ae34b55e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -43,6 +43,7 @@ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -454,6 +455,31 @@ class PythonMLLibAPI extends Serializable { Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method)) } + /** + * Java stub for mllib Statistics.chiSqTest() + */ + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + if (expected == null) { + Statistics.chiSqTest(observed) + } else { + Statistics.chiSqTest(observed, expected) + } + } + + /** + * Java stub for mllib Statistics.chiSqTest(observed: Matrix) + */ + def chiSqTest(observed: Matrix): ChiSqTestResult = { + Statistics.chiSqTest(observed) + } + + /** + * Java stub for mllib 
Statistics.chiSqTest(RDD[LabelPoint]) + */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = { + Statistics.chiSqTest(data.rdd) + } + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark private def getCorrNameOrDefault(method: String) = { if (method == null) CorrelationNames.defaultCorrName else method @@ -736,7 +762,7 @@ private[spark] object SerDe extends Serializable { def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { jRDD.rdd.mapPartitions { iter => initialize() // let it called in executor - new PythonRDD.AutoBatchedPickler(iter) + new SerDeUtil.AutoBatchedPickler(iter) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 7858ec602483f..078fbfbe4f0e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve { */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( - seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points), combOp = _ + _ ) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f7..ac217edc619ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,22 +17,26 @@ package org.apache.spark.mllib.linalg -import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.catalyst.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) sealed trait Vector extends Serializable { /** @@ -74,6 +78,65 @@ sealed trait Vector extends Serializable { } } +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". 
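To make the new UDT concrete: a minimal round-trip sketch in Scala, mirroring the VectorUDT test added to VectorsSuite later in this patch and using only the classes touched in this hunk (the schema it relies on is the StructType defined immediately below).

  import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}

  val udt = new VectorUDT()
  val sv: Vector = Vectors.sparse(2, Array(1), Array(2.0))   // sparse vectors serialize with type = 0
  val row = udt.serialize(sv)                                // Row(type, size, indices, values)
  val back: Vector = udt.deserialize(row)
  assert(back.toArray.sameElements(sv.toArray))              // dense vectors round-trip the same way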
+ StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(4) + obj match { + case sv: SparseVector => + row.setByte(0, 0) + row.setInt(1, sv.size) + row.update(2, sv.indices.toSeq) + row.update(3, sv.values.toSeq) + case dv: DenseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.setNullAt(2) + row.update(3, dv.values.toSeq) + } + row + } + + override def deserialize(datum: Any): Vector = { + datum match { + case row: Row => + require(row.length == 4, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + val tpe = row.getByte(0) + tpe match { + case 0 => + val size = row.getInt(1) + val indices = row.getAs[Iterable[Int]](2).toArray + val values = row.getAs[Iterable[Double]](3).toArray + new SparseVector(size, indices, values) + case 1 => + val values = row.getAs[Iterable[Double]](3).toArray + new DenseVector(values) + } + } + } + + override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" + + override def userClass: Class[Vector] = classOf[Vector] +} + /** * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. * We don't use the name `Vector` because Scala imports @@ -191,6 +254,7 @@ object Vectors { /** * A dense vector represented by a value array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, val indices: Array[Int], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index b5e403bc8c14d..57c0768084e41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD @@ -28,8 +29,8 @@ import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. */ -private[mllib] -class RDDFunctions[T: ClassTag](self: RDD[T]) { +@DeveloperApi +class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding @@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. 
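Since this change turns sliding into a method returning RDD[Array[T]] rather than RDD[Seq[T]], a short usage sketch via the implicit conversion in the companion object, assuming an active SparkContext named sc:

  import org.apache.spark.mllib.rdd.RDDFunctions._

  val rdd = sc.parallelize(1 to 7, numSlices = 3)
  // Windows may span partition boundaries; with windowSize = 3 this yields
  // Array(1,2,3), Array(2,3,4), ..., Array(5,6,7).
  val windows: Array[Array[Int]] = rdd.sliding(3).collect()
  windows.foreach(w => println(w.mkString(",")))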
*/ - def sliding(windowSize: Int): RDD[Seq[T]] = { + def sliding(windowSize: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") if (windowSize == 1) { - self.map(Seq(_)) + self.map(Array(_)) } else { new SlidingRDD[T](self, windowSize) } @@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } -private[mllib] +@DeveloperApi object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index dd80782c0f001..35e81fcb3de0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] */ private[mllib] class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) - extends RDD[Seq[T]](parent) { + extends RDD[Array[T]](parent) { require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") - override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) .sliding(windowSize) .withPartial(false) + .map(_.toArray) } override def getPreferredLocations(split: Partition): Seq[String] = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala index 1a847201ce157..f729344a682e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala @@ -17,30 +17,49 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ - +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy} -import org.apache.spark.Logging -import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: - * A class that implements gradient boosting for regression and binary classification problems. + * A class that implements Stochastic Gradient Boosting + * for regression and binary classification problems. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes: + * - This currently can be run with several loss functions. However, only SquaredError is + * fully supported. 
Specifically, the loss function should be used to compute the gradient + * (to re-label training instances on each iteration) and to weight weak hypotheses. + * Currently, gradients are computed correctly for the available loss functions, + * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. + * Running with those losses will likely behave reasonably, but lacks the same guarantees. + * * @param boostingStrategy Parameters for the gradient boosting algorithm */ @Experimental class GradientBoosting ( private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { + boostingStrategy.weakLearnerParams.algo = Regression + boostingStrategy.weakLearnerParams.impurity = impurity.Variance + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + boostingStrategy.weakLearnerParams.numClassesForClassification = + boostingStrategy.numClassesForClassification + + boostingStrategy.assertValid() + /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. @@ -51,6 +70,7 @@ class GradientBoosting ( algo match { case Regression => GradientBoosting.boost(input, boostingStrategy) case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoosting.boost(remappedInput, boostingStrategy) case _ => @@ -118,120 +138,32 @@ object GradientBoosting extends Logging { } /** - * Method to train a gradient boosting binary classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) 
- * @return WeightedEnsembleModel that can be used for prediction + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] */ - def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) 
- * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + train(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] */ def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainClassifier(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] */ def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainRegressor(input.rdd, boostingStrategy) } - /** * Internal method for performing regression using trees as base learners. * @param input training dataset @@ -247,15 +179,17 @@ object GradientBoosting extends Logging { timer.start("init") // Initialize gradient boosting parameters - val numEstimators = boostingStrategy.numEstimators - val baseLearners = new Array[DecisionTreeModel](numEstimators) - val baseLearnerWeights = new Array[Double](numEstimators) + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate val strategy = boostingStrategy.weakLearnerParams // Cache input - input.persist(StorageLevel.MEMORY_AND_DISK) + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } timer.stop("init") @@ -264,7 +198,7 @@ object GradientBoosting extends Logging { logDebug("##########") var data = input - // 1. 
Initialize tree + // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(strategy).train(data) baseLearners(0) = firstTreeModel @@ -280,7 +214,7 @@ object GradientBoosting extends Logging { point.features)) var m = 1 - while (m < numEstimators) { + while (m < numIterations) { timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) @@ -289,6 +223,9 @@ object GradientBoosting extends Logging { timer.stop(s"building tree $m") // Create partial model baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), @@ -305,8 +242,6 @@ object GradientBoosting extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - // 3. Output classifier new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 501d9ff9ea9b7..abbda040bd528 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,6 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.{Gini, Variance} import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** @@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * @param algo Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. + * @param numIterations Number of iterations of boosting. In other words, the number of + * weak hypotheses used in the final model. * @param loss Loss function used for minimization during gradient boosting. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. * @param numClassesForClassification Number of classes for classification. * (Ignored for regression.) + * This setting overrides any setting in [[weakLearnerParams]]. * Default value is 2 (binary classification). - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are * supported. 
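The GradientBoostedTrees example added earlier in this patch exercises the case class defined immediately below through its companion's defaultParams; condensed to a sketch, assuming an existing RDD[LabeledPoint] named training:

  import org.apache.spark.mllib.tree.GradientBoosting
  import org.apache.spark.mllib.tree.configuration.BoostingStrategy

  val boostingStrategy = BoostingStrategy.defaultParams("Classification")
  boostingStrategy.numIterations = 10                  // number of weak hypotheses in the ensemble
  boostingStrategy.numClassesForClassification = 2     // binary classification only (see assertValid)
  boostingStrategy.weakLearnerParams.maxDepth = 5      // depth of each regression tree
  val model = GradientBoosting.trainClassifier(training, boostingStrategy)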
*/ @Experimental case class BoostingStrategy( // Required boosting parameters - algo: Algo, - @BeanProperty var numEstimators: Int, + @BeanProperty var algo: Algo, + @BeanProperty var numIterations: Int, @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var subsamplingRate: Double = 1.0, @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " + - s"$learningRate.") - require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " + - s"$learningRate.") - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo weakLearnerParams.numClassesForClassification = numClassesForClassification - weakLearnerParams.subsamplingRate = subsamplingRate + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification == 2) + case Regression => + // nothing + case _ => + throw new IllegalArgumentException( + s"BoostingStrategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(learningRate > 0 && learningRate <= 1, + "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") + } } @Experimental @@ -82,28 +93,17 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: Algo): BoostingStrategy = { - val treeStrategy = defaultWeakLearnerParams(algo) + def defaultParams(algo: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy("Regression") + treeStrategy.maxDepth = 3 algo match { - case Classification => - new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy) - case Regression => - new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy) + case "Classification" => + new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + case "Regression" => + new BoostingStrategy(Algo.withName(algo), 100, SquaredError, + weakLearnerParams = treeStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } } - - /** - * Returns default configuration for the weak learner (decision tree) algorithm - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @return Configuration for weak learner - */ - def defaultWeakLearnerParams(algo: Algo): Strategy = { - // Note: Regression tree used even for classification for GBT. 
- new Strategy(Regression, Variance, 3) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d09295c507d67..b5b1f82177edc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Experimental class Strategy ( - val algo: Algo, + @BeanProperty var algo: Algo, @BeanProperty var impurity: Impurity, @BeanProperty var maxDepth: Int, @BeanProperty var numClassesForClassification: Int = 2, @@ -85,17 +85,9 @@ class Strategy ( @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - if (algo == Classification) { - require(numClassesForClassification >= 2) - } - require(minInstancesPerNode >= 1, - s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") - require(maxMemoryInMB <= 10240, - s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") - - val isMulticlassClassification = + def isMulticlassClassification = algo == Classification && numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures + def isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** @@ -112,6 +104,23 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Sets categoricalFeaturesInfo using a Java Map. + */ + def setCategoricalFeaturesInfo( + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { + setCategoricalFeaturesInfo( + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + /** * Check validity of parameters. * Throws exception if invalid. @@ -143,6 +152,26 @@ class Strategy ( s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + s" feature $feature has $arity categories. 
The number of categories should be >= 2.") } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } +} + +@Experimental +object Strategy { + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + def defaultStrategy(algo: String): Strategy = algo match { + case "Classification" => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClassesForClassification = 2) + case "Regression" => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClassesForClassification = 0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cd651fe2d2ddf..93a84fe07b32a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite { throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") } } + + test("VectorUDT") { + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0, 2.0) + val sv0 = Vectors.sparse(2, Array.empty, Array.empty) + val sv1 = Vectors.sparse(2, Array(1), Array(2.0)) + val udt = new VectorUDT() + for (v <- Seq(dv0, dv1, sv0, sv1)) { + assert(v === udt.deserialize(udt.serialize(v))) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 27a19f793242b..4ef67a40b9f49 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) assert(rdd.partitions.size === data.length) - val sliding = rdd.sliding(3) - val expected = data.flatMap(x => x).sliding(3).toList - assert(sliding.collect().toList === expected) + val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) + val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) + assert(sliding === expected) } test("treeAggregate") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala index 970fff82215e2..99a02eda60baf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala @@ -22,9 +22,8 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} -import org.apache.spark.mllib.tree.impurity.{Variance, Gini} +import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.LocalSparkContext @@ -34,9 +33,8 @@ 
import org.apache.spark.mllib.util.LocalSparkContext class GradientBoostingSuite extends FunSuite with LocalSparkContext { test("Regression with continuous features: SquaredError") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, 1, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } test("Regression with continuous features: Absolute Error") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } } - test("Binary classification with continuous features: Log Loss") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) 
EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) @@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { object GradientBoostingSuite { // Combinations for estimators, learning rates and subsamplingRate - val testCombinations - = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index c0a62e00432a3..5cb433232e714 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -30,7 +30,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) baggedRDD.collect().foreach { baggedPoint => assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) } @@ -44,7 +44,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -60,7 +60,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -75,7 +75,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -91,7 +91,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() 
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) diff --git a/network/common/pom.xml b/network/common/pom.xml index ea887148d98ba..6144548a8f998 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -50,6 +50,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index a271841e4e56c..5bc6e5a2418a9 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,12 +17,16 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.MessageDecoder; @@ -64,8 +68,17 @@ public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this.decoder = new MessageDecoder(); } + /** + * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning + * a new Client. Bootstraps will be executed synchronously, and must run successfully in order + * to create a Client. + */ + public TransportClientFactory createClientFactory(List bootstraps) { + return new TransportClientFactory(this, bootstraps); + } + public TransportClientFactory createClientFactory() { - return new TransportClientFactory(this); + return createClientFactory(Lists.newArrayList()); } /** Create a server which will attempt to bind to a specific port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 89ed79bc63903..5fa1527ddff92 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -30,6 +30,7 @@ import io.netty.channel.DefaultFileRegion; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.LimitedInputStream; /** * A {@link ManagedBuffer} backed by a segment in a file. 
@@ -101,7 +102,7 @@ public InputStream createInputStream() throws IOException { try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return ByteStreams.limit(is, length); + return new LimitedInputStream(is, length); } catch (IOException e) { try { if (is != null) { diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 01c143fff423c..4e944114e8176 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -18,11 +18,12 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; @@ -117,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -148,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(requestId); - callback.onFailure(new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -176,6 +185,8 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + throw Throwables.propagate(e.getCause()); } catch (Exception e) { throw Throwables.propagate(e); } @@ -186,4 +197,12 @@ public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java new file mode 100644 index 0000000000000..65e8020e34121 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ +public interface TransportClientBootstrap { + /** Performs the bootstrapping operation, throwing an exception on failure. */ + public void doBootstrap(TransportClient client) throws RuntimeException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 0b4a1d8286407..397d3a8455c86 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -18,13 +18,17 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; @@ -47,22 +51,29 @@ * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same - * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for - * all {@link TransportClient}s. + * TransportClient for the same remote host. It also shares a single worker thread pool for + * all TransportClients. + * + * TransportClients will be reused whenever possible. Prior to completing the creation of a new + * TransportClient, all given {@link TransportClientBootstrap}s will be run. 
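Because the factory below catches non-RuntimeExceptions precisely so that a bootstrap may be written in Scala, here is a minimal Scala sketch of the interface defined above; the blocking sendRpcSync helper on TransportClient is an assumption here, and any request/response exchange would serve.

  import org.apache.spark.network.client.{TransportClient, TransportClientBootstrap}

  class HandshakeBootstrap(payload: Array[Byte]) extends TransportClientBootstrap {
    // Runs synchronously inside createClient(); any exception aborts and closes the client.
    override def doBootstrap(client: TransportClient): Unit = {
      client.sendRpcSync(payload, 10000)   // assumed blocking RPC helper (timeout in ms)
    }
  }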
*/ public class TransportClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; + private final List clientBootstraps; private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private EventLoopGroup workerGroup; - public TransportClientFactory(TransportContext context) { - this.context = context; + public TransportClientFactory( + TransportContext context, + List clientBootstraps) { + this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); @@ -72,21 +83,26 @@ public TransportClientFactory(TransportContext context) { } /** - * Create a new BlockFetchingClient connecting to the given remote host / port. + * Create a new {@link TransportClient} connecting to the given remote host / port. This will + * reuse TransportClients if they are still active and are for the same remote address. Prior + * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s + * that are registered with this factory. * - * This blocks until a connection is successfully established. + * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) { + public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); TransportClient cachedClient = connectionPool.get(address); if (cachedClient != null) { if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); return cachedClient; } else { + logger.info("Found inactive connection to {}, closing it.", address); connectionPool.remove(address, cachedClient); // Remove inactive clients. 
} } @@ -104,34 +120,55 @@ public TransportClient createClient(String remoteHost, int remotePort) { // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); - final AtomicReference client = new AtomicReference(); + final AtomicReference clientRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); - client.set(clientHandler.getClient()); + clientRef.set(clientHandler.getClient()); } }); // Connect to the remote server + long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new RuntimeException( + throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { - throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); + } + + TransportClient client = clientRef.get(); + assert client != null : "Channel future completed successfully with null client"; + + // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + client.close(); + throw Throwables.propagate(e); } + long postBootstrap = System.currentTimeMillis(); - // Successful connection -- in the event that two threads raced to create a client, we will + // Successful connection & bootstrap -- in the event that two threads raced to create a client, // use the first one that was put into the connectionPool and close the one we made here. 
- assert client.get() != null : "Channel future completed successfully with null client"; - TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + TransportClient oldClient = connectionPool.putIfAbsent(address, client); if (oldClient == null) { - return client.get(); + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); + return client; } else { - logger.debug("Two clients were created concurrently, second one will be disposed."); - client.get().close(); + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); + client.close(); return oldClient; } } @@ -162,7 +199,7 @@ public void close() { */ private PooledByteBufAllocator createPooledByteBufAllocator() { return new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(), getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), getPrivateStaticField("DEFAULT_PAGE_SIZE"), diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index d8965590b34da..2044afb0d85db 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.client; +import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -94,7 +95,7 @@ public void channelUnregistered() { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); - failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); + failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 4cb8becc3ed22..91d1e8a538a77 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { // All messages have the frame length, message type, and message itself. 
int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); long frameLength = headerLength + bodyLength; - ByteBuf header = ctx.alloc().buffer(headerLength); + ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeLong(frameLength); msgType.encode(header); in.encode(header); diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 5a3f003726fc1..1502b7489e864 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -21,7 +21,7 @@ import org.apache.spark.network.client.TransportClient; /** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ -public class NoOpRpcHandler implements RpcHandler { +public class NoOpRpcHandler extends RpcHandler { private final StreamManager streamManager; public NoOpRpcHandler() { diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2369dc6203944..2ba92a40f8b0a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,22 +23,33 @@ /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ -public interface RpcHandler { +public abstract class RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. + * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ - void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. */ - StreamManager getStreamManager(); + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. 
+ */ + public void connectionTerminated(TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 17fe9001b35cc..1580180cc17e9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -86,6 +86,7 @@ public void channelUnregistered() { for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); } + rpcHandler.connectionTerminated(reverseClient); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 70da48ca8ee79..579676c2c3564 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -71,11 +72,14 @@ private void init(int portToBind) { NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); EventLoopGroup workerGroup = bossGroup; + PooledByteBufAllocator allocator = new PooledByteBufAllocator( + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred()); + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + .option(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.ALLOCATOR, allocator); if (conf.backLog() > 0) { bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 40b71b0c87a47..75c4a3981a240 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,17 +17,27 @@ package org.apache.spark.network.util; +import java.nio.ByteBuffer; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Closeable; +import java.io.File; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import com.google.common.base.Preconditions; import com.google.common.io.Closeables; +import com.google.common.base.Charsets; +import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * General utilities available in the network package. Many of these are sourced from Spark's + * own Utils, just accessible within this package. + */ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); @@ -73,4 +83,73 @@ public static int nonNegativeHash(Object obj) { int hash = obj.hashCode(); return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; } + + /** + * Convert the given string to a byte buffer. 
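With RpcHandler now an abstract class, an implementation only has to provide receive() and getStreamManager(); connectionTerminated() is an optional hook with a no-op default. A sketch of a trivial subclass (EchoRpcHandler is hypothetical):

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;

/** Hypothetical handler that echoes every RPC payload back to its sender. */
class EchoRpcHandler extends RpcHandler {
  private final StreamManager streamManager;

  EchoRpcHandler(StreamManager streamManager) {
    this.streamManager = streamManager;
  }

  @Override
  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
    callback.onSuccess(message);  // reply with the exact bytes that were received
  }

  @Override
  public StreamManager getStreamManager() {
    return streamManager;
  }

  @Override
  public void connectionTerminated(TransportClient client) {
    // per-channel cleanup would go here; the inherited default is a no-op
  }
}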
The resulting buffer can be + * converted back to the same string through {@link #bytesToString(ByteBuffer)}. + */ + public static ByteBuffer stringToBytes(String s) { + return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be + * converted back to the same byte buffer through {@link #stringToBytes(String)}. + */ + public static String bytesToString(ByteBuffer b) { + return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + } + + /* + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. + */ + public static void deleteRecursively(File file) throws IOException { + if (file == null) { return; } + + if (file.isDirectory() && !isSymlink(file)) { + IOException savedIOException = null; + for (File child : listFilesSafely(file)) { + try { + deleteRecursively(child); + } catch (IOException e) { + // In case of multiple exceptions, only last one will be thrown + savedIOException = e; + } + } + if (savedIOException != null) { + throw savedIOException; + } + } + + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } + } + + private static File[] listFilesSafely(File file) throws IOException { + if (file.exists()) { + File[] files = file.listFiles(); + if (files == null) { + throw new IOException("Failed to list files for dir: " + file); + } + return files; + } else { + return new File[0]; + } + } + + private static boolean isSymlink(File file) throws IOException { + Preconditions.checkNotNull(file); + File fileInCanonicalDir = null; + if (file.getParent() == null) { + fileInCanonicalDir = file; + } else { + fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName()); + } + return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java new file mode 100644 index 0000000000000..63ca43c046525 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.base.Preconditions; + +/** + * Wraps a {@link InputStream}, limiting the number of bytes which can be read. 
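stringToBytes() and bytesToString() are documented as inverses; a small round-trip sketch (class name hypothetical):

import java.nio.ByteBuffer;

import org.apache.spark.network.util.JavaUtils;

class StringBytesRoundTrip {
  public static void main(String[] args) {
    ByteBuffer buf = JavaUtils.stringToBytes("app-20141106-0001");  // UTF-8 encode
    String back = JavaUtils.bytesToString(buf);                     // UTF-8 decode
    System.out.println("app-20141106-0001".equals(back));           // prints true
  }
}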
+ * + * This code is from Guava's 14.0 source code, because there is no compatible way to + * use this functionality in both a Guava 11 environment and a Guava >14 environment. + */ +public final class LimitedInputStream extends FilterInputStream { + private long left; + private long mark = -1; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + Preconditions.checkNotNull(in); + Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); + left = limit; + } + @Override public int available() throws IOException { + return (int) Math.min(in.available(), left); + } + // it's okay to mark even if mark isn't supported, as reset won't work + @Override public synchronized void mark(int readLimit) { + in.mark(readLimit); + mark = left; + } + @Override public int read() throws IOException { + if (left == 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + --left; + } + return result; + } + @Override public int read(byte[] b, int off, int len) throws IOException { + if (left == 0) { + return -1; + } + len = (int) Math.min(len, left); + int result = in.read(b, off, len); + if (result != -1) { + left -= result; + } + return result; + } + @Override public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + in.reset(); + left = mark; + } + @Override public long skip(long n) throws IOException { + n = Math.min(n, left); + long skipped = in.skip(n); + left -= skipped; + return skipped; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index b1872341198e0..2a7664fe89388 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -37,13 +37,17 @@ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. */ public class NettyUtils { - /** Creates a Netty EventLoopGroup based on the IOMode. */ - public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { - - ThreadFactory threadFactory = new ThreadFactoryBuilder() + /** Creates a new ThreadFactory which prefixes each thread with the given name. */ + public static ThreadFactory createThreadFactory(String threadPoolPrefix) { + return new ThreadFactoryBuilder() .setDaemon(true) - .setNameFormat(threadPrefix + "-%d") + .setNameFormat(threadPoolPrefix + "-%d") .build(); + } + + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + ThreadFactory threadFactory = createThreadFactory(threadPrefix); switch (mode) { case NIO: diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index a68f38e0e94c9..787a8f0031af1 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) { /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + /** If true, we will prefer allocating off-heap byte buffers within Netty. 
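A sketch of the wrapper in use, reading a stream that is capped at three bytes (class name hypothetical):

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import org.apache.spark.network.util.LimitedInputStream;

class LimitedReadExample {
  public static void main(String[] args) throws IOException {
    byte[] data = {10, 20, 30, 40, 50};
    InputStream in = new LimitedInputStream(new ByteArrayInputStream(data), 3);
    int read = 0;
    while (in.read() != -1) {
      read++;  // sees 10, 20, 30, then end-of-stream even though the source has 5 bytes
    }
    System.out.println(read);  // prints 3
  }
}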
*/ + public boolean preferDirectBufs() { + return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + } + /** Connect timeout in secs. Default 120 secs. */ public int connectionTimeoutMs() { return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; @@ -55,4 +60,19 @@ public int connectionTimeoutMs() { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } + + /** + * Max number of times we will try IO exceptions (such as connection timeouts) per request. + * If set to 0, we will not do any retries. + */ + public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. + * Only relevant if maxIORetries > 0. + */ + public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 5a10fdb3842ef..822bef1d81b2a 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.util.concurrent.TimeoutException; import org.junit.After; @@ -57,7 +58,7 @@ public void tearDown() { } @Test - public void createAndReuseBlockClients() throws TimeoutException { + public void createAndReuseBlockClients() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); @@ -70,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException { } @Test - public void neverReturnInactiveClients() throws Exception { + public void neverReturnInactiveClients() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); @@ -88,7 +89,7 @@ public void neverReturnInactiveClients() throws Exception { } @Test - public void closeBlockClientsWithFactory() throws TimeoutException { + public void closeBlockClientsWithFactory() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index d271704d98a7a..fe5681d463499 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -51,6 +51,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 0000000000000..7bc91e375371f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,74 @@ +/* + * 
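The new configuration getters and the keys they read, shown against an existing TransportConf (helper class hypothetical):

import org.apache.spark.network.util.TransportConf;

class TransportConfSettings {
  static void describe(TransportConf conf) {
    System.out.println("preferDirectBufs = " + conf.preferDirectBufs());  // spark.shuffle.io.preferDirectBufs, default true
    System.out.println("saslRTTimeout    = " + conf.saslRTTimeout());     // spark.shuffle.sasl.timeout, default 30000 ms
    System.out.println("maxIORetries     = " + conf.maxIORetries());      // spark.shuffle.io.maxRetries, default 3
    System.out.println("ioRetryWaitTime  = " + conf.ioRetryWaitTime());   // spark.shuffle.io.retryWaitMs, default 5000 ms
  }
}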
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000000..599cc6428c90e --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
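Client-side wiring mirrors what ExternalShuffleClient does later in this patch: build the bootstrap list, then ask the TransportContext for a factory. A condensed sketch (SaslClientWiring is hypothetical):

import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.sasl.SaslClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.util.TransportConf;

class SaslClientWiring {
  static TransportClientFactory saslFactory(
      TransportConf conf, String appId, SecretKeyHolder keyHolder) {
    TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
    List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
    bootstraps.add(new SaslClientBootstrap(conf, appId, keyHolder));
    // Every connection handed out by this factory authenticates before callers see it.
    return context.createClientFactory(bootstraps);
  }
}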
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + // tag + appIdLength + appId + payloadLength + payload + return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + byte[] idBytes = appId.getBytes(Charsets.UTF_8); + buf.writeInt(idBytes.length); + buf.writeBytes(idBytes); + buf.writeInt(payload.length); + buf.writeBytes(payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); + } + + int idLength = buf.readInt(); + byte[] idBytes = new byte[idLength]; + buf.readBytes(idBytes); + + int payloadLength = buf.readInt(); + byte[] payload = new byte[payloadLength]; + buf.readBytes(payload); + + return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 0000000000000..3777a18e33f78 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
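SaslMessage is package-private, so a round-trip check has to live in the same package. A sketch of encode/decode symmetry (class name hypothetical):

package org.apache.spark.network.sasl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

class SaslMessageRoundTrip {
  public static void main(String[] args) {
    SaslMessage msg = new SaslMessage("app-1", new byte[] {1, 2, 3});
    ByteBuf buf = Unpooled.buffer(msg.encodedLength());
    msg.encode(buf);                                // tag byte, appId, then payload
    SaslMessage decoded = SaslMessage.decode(buf);  // throws if the tag byte does not match
    System.out.println(decoded.appId + " / " + decoded.payload.length);  // prints "app-1 / 3"
  }
}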
+ */ + +package org.apache.spark.network.sasl; + +import java.util.concurrent.ConcurrentMap; + +import com.google.common.base.Charsets; +import com.google.common.collect.Maps; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +public class SaslRpcHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Maps each channel to its SASL authentication state. */ + private final ConcurrentMap channelAuthenticationMap; + + public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.channelAuthenticationMap = Maps.newConcurrentMap(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + SparkSaslServer saslServer = channelAuthenticationMap.get(client); + if (saslServer != null && saslServer.isComplete()) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); + channelAuthenticationMap.put(client, saslServer); + } + + byte[] response = saslServer.response(saslMessage.payload); + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + } + callback.onSuccess(response); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java new file mode 100644 index 0000000000000..81d5766794688 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
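On the server side, the SASL handler simply wraps the application's RpcHandler; the wrapped handler sees a message only after its channel has authenticated. A wiring sketch (SaslServerWiring is hypothetical; creating the actual TransportServer from the returned context is elided):

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.util.TransportConf;

class SaslServerWiring {
  /** Wraps the application's RpcHandler so every channel must authenticate first. */
  static TransportContext secureContext(
      TransportConf conf, RpcHandler appHandler, SecretKeyHolder keyHolder) {
    RpcHandler saslHandler = new SaslRpcHandler(appHandler, keyHolder);
    return new TransportContext(conf, saslHandler);
  }
}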
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +/** + * Interface for getting a secret key associated with some application. + */ +public interface SecretKeyHolder { + /** + * Gets an appropriate SASL User for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL user. + */ + String getSaslUser(String appId); + + /** + * Gets an appropriate SASL secret key for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. + */ + String getSecretKey(String appId); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java new file mode 100644 index 0000000000000..351c7930a900f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.lang.Override; +import java.nio.ByteBuffer; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; + +/** + * A class that manages shuffle secret used by the external shuffle service. + */ +public class ShuffleSecretManager implements SecretKeyHolder { + private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); + private final ConcurrentHashMap shuffleSecretMap; + + // Spark user used for authenticating SASL connections + // Note that this must match the value in org.apache.spark.SecurityManager + private static final String SPARK_SASL_USER = "sparkSaslUser"; + + public ShuffleSecretManager() { + shuffleSecretMap = new ConcurrentHashMap(); + } + + /** + * Register an application with its secret. + * Executors need to first authenticate themselves with the same secret before + * fetching shuffle files written by other executors in this application. 
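A minimal SecretKeyHolder that hands out one fixed user and secret for every appId, usable on both the SaslClientBootstrap and SaslRpcHandler sides (class name hypothetical); ShuffleSecretManager below is the real per-application implementation:

import org.apache.spark.network.sasl.SecretKeyHolder;

/** Hypothetical holder that returns one fixed user/secret pair for every application. */
class SingleSecretKeyHolder implements SecretKeyHolder {
  private final String user;
  private final String secret;

  SingleSecretKeyHolder(String user, String secret) {
    this.user = user;
    this.secret = secret;
  }

  @Override
  public String getSaslUser(String appId) {
    return user;
  }

  @Override
  public String getSecretKey(String appId) {
    return secret;
  }
}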
+ */ + public void registerApp(String appId, String shuffleSecret) { + if (!shuffleSecretMap.contains(appId)) { + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); + } else { + logger.debug("Application {} already registered", appId); + } + } + + /** + * Register an application with its secret specified as a byte buffer. + */ + public void registerApp(String appId, ByteBuffer shuffleSecret) { + registerApp(appId, JavaUtils.bytesToString(shuffleSecret)); + } + + /** + * Unregister an application along with its secret. + * This is called when the application terminates. + */ + public void unregisterApp(String appId) { + if (shuffleSecretMap.contains(appId)) { + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); + } else { + logger.warn("Attempted to unregister application {} when it is not registered", appId); + } + } + + /** + * Return the Spark user for authenticating SASL connections. + */ + @Override + public String getSaslUser(String appId) { + return SPARK_SASL_USER; + } + + /** + * Return the secret key registered with the given application. + * This key is used to authenticate the executors before they can fetch shuffle files + * written by this application from the external shuffle service. If the specified + * application is not registered, return null. + */ + @Override + public String getSecretKey(String appId) { + return shuffleSecretMap.get(appId); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java new file mode 100644 index 0000000000000..9abad1f30a259 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + +import com.google.common.base.Throwables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.sasl.SparkSaslServer.*; + +/** + * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. 
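A typical lifecycle from the shuffle service's point of view, assuming only the methods shown above (class name hypothetical):

import org.apache.spark.network.sasl.ShuffleSecretManager;

class SecretManagerLifecycle {
  public static void main(String[] args) {
    ShuffleSecretManager secrets = new ShuffleSecretManager();
    secrets.registerApp("app-20141106-0001", "s3cr3t");      // when the application starts
    String key = secrets.getSecretKey("app-20141106-0001");  // consulted during SASL auth
    System.out.println("s3cr3t".equals(key));                // prints true
    secrets.unregisterApp("app-20141106-0001");              // when the application terminates
  }
}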
This client initializes the protocol via a + * firstToken, which is then followed by a set of challenges and responses. + */ +public class SparkSaslClient { + private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); + + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslClient saslClient; + + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, + SASL_PROPS, new ClientCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** Used to initiate SASL handshake with server. */ + public synchronized byte[] firstToken() { + if (saslClient != null && saslClient.hasInitialResponse()) { + try { + return saslClient.evaluateChallenge(new byte[0]); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } else { + return new byte[0]; + } + } + + /** Determines whether the authentication exchange has completed. */ + public synchronized boolean isComplete() { + return saslClient != null && saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param token contains server's SASL token + * @return client's response SASL token + */ + public synchronized byte[] response(byte[] token) { + try { + return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public synchronized void dispose() { + if (saslClient != null) { + try { + saslClient.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslClient = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class ClientCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL client callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL client callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL client callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof RealmChoiceCallback) { + // ignore (?) + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 0000000000000..e87b17ead1e1a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + static final String DEFAULT_REALM = "default"; + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + static final String DIGEST = "DIGEST-MD5"; + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + static final Map SASL_PROPS = ImmutableMap.builder() + .put(Sasl.QOP, "auth") + .put(Sasl.SERVER_AUTH, "true") + .build(); + + /** Identifier for a certain secret key within the secretKeyHolder. */ + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslServer saslServer; + + public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + new DigestCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Determines whether the authentication exchange has completed successfully. + */ + public synchronized boolean isComplete() { + return saslServer != null && saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. 
+ * @param token Server's SASL token + * @return response to send back to the server. + */ + public synchronized byte[] response(byte[] token) { + try { + return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public synchronized void dispose() { + if (saslServer != null) { + try { + saslServer.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslServer = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. + */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a9dff31decc83..75ebf8c7b0604 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -41,7 +41,7 @@ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- * level shuffle block. 
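Because both endpoints are plain state machines, the whole DIGEST-MD5 exchange can be exercised in memory by feeding each side's tokens to the other, which is what SaslClientBootstrap and SaslRpcHandler do over RPC. A sketch assuming a shared secret on both sides (class name hypothetical):

import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SparkSaslClient;
import org.apache.spark.network.sasl.SparkSaslServer;

class LocalSaslHandshake {
  public static void main(String[] args) {
    SecretKeyHolder keyHolder = new SecretKeyHolder() {
      @Override public String getSaslUser(String appId) { return "sparkSaslUser"; }
      @Override public String getSecretKey(String appId) { return "shared-secret"; }
    };

    SparkSaslClient client = new SparkSaslClient("app-1", keyHolder);
    SparkSaslServer server = new SparkSaslServer("app-1", keyHolder);
    try {
      byte[] clientToken = client.firstToken();  // empty initial token for DIGEST-MD5
      while (!client.isComplete()) {
        byte[] serverToken = server.response(clientToken);  // what SaslRpcHandler does per RPC
        clientToken = client.response(serverToken);         // what SaslClientBootstrap does
      }
      System.out.println(client.isComplete() && server.isComplete());  // true on success
    } finally {
      client.dispose();
      server.dispose();
    }
  }
}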
*/ -public class ExternalShuffleBlockHandler implements RpcHandler { +public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); private final ExternalShuffleBlockManager blockManager; @@ -94,9 +94,11 @@ public StreamManager getStreamManager() { return streamManager; } - /** For testing, clears all executors registered with "RegisterExecutor". */ - @VisibleForTesting - public void clearRegisteredExecutors() { - blockManager.clearRegisteredExecutors(); + /** + * Removes an application (once it has been terminated), and optionally will clean up any + * local directories associated with the executors of that application in a separate thread. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + blockManager.applicationRemoved(appId, cleanupLocalDirs); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 6589889fe1be7..98fcfb82aa5d1 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -21,9 +21,15 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,13 +49,22 @@ public class ExternalShuffleBlockManager { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); - // Map from "appId-execId" to the executor's configuration. - private final ConcurrentHashMap executors = - new ConcurrentHashMap(); + // Map containing all registered executors' metadata. + private final ConcurrentMap executors; - // Returns an id suitable for a single executor within a single application. - private String getAppExecId(String appId, String execId) { - return appId + "-" + execId; + // Single-threaded Java executor used to perform expensive recursive directory deletion. + private final Executor directoryCleaner; + + public ExternalShuffleBlockManager() { + // TODO: Give this thread a name. + this(Executors.newSingleThreadExecutor()); + } + + // Allows tests to have more control over when directories are cleaned up. + @VisibleForTesting + ExternalShuffleBlockManager(Executor directoryCleaner) { + this.executors = Maps.newConcurrentMap(); + this.directoryCleaner = directoryCleaner; } /** Registers a new Executor with all the configuration we need to find its shuffle files. 
*/ @@ -57,7 +72,7 @@ public void registerExecutor( String appId, String execId, ExecutorShuffleInfo executorInfo) { - String fullId = getAppExecId(appId, execId); + AppExecId fullId = new AppExecId(appId, execId); logger.info("Registered executor {} with {}", fullId, executorInfo); executors.put(fullId, executorInfo); } @@ -78,7 +93,7 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { int mapId = Integer.parseInt(blockIdParts[2]); int reduceId = Integer.parseInt(blockIdParts[3]); - ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId)); + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); @@ -94,6 +109,56 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { } } + /** + * Removes our metadata of all executors registered for the given application, and optionally + * also deletes the local directories associated with the executors of that application in a + * separate thread. + * + * It is not valid to call registerExecutor() for an executor with this appId after invoking + * this method. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + Iterator> it = executors.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + AppExecId fullId = entry.getKey(); + final ExecutorShuffleInfo executor = entry.getValue(); + + // Only touch executors associated with the appId that was removed. + if (appId.equals(fullId.appId)) { + it.remove(); + + if (cleanupLocalDirs) { + logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(new Runnable() { + @Override + public void run() { + deleteExecutorDirs(executor.localDirs); + } + }); + } + } + } + } + + /** + * Synchronously deletes each directory one at a time. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteExecutorDirs(String[] dirs) { + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir)); + logger.debug("Successfully cleaned up directory: " + localDir); + } catch (Exception e) { + logger.error("Failed to delete directory: " + localDir, e); + } + } + } + /** * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockManager. @@ -146,9 +211,36 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) return new File(new File(localDir, String.format("%02x", subDirId)), filename); } - /** For testing, clears all registered executors. */ - @VisibleForTesting - void clearRegisteredExecutors() { - executors.clear(); + /** Simply encodes an executor's full ID, which is appId + execId. 
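Because the directory cleaner is injectable, callers in the same package can make cleanup synchronous, which is mainly useful in tests. A sketch that assumes access to the package-private constructor (class name hypothetical):

package org.apache.spark.network.shuffle;

import java.util.concurrent.Executor;

class SynchronousCleanupExample {
  /** Builds a manager whose applicationRemoved(appId, true) deletes directories on the calling thread. */
  static ExternalShuffleBlockManager newManagerWithInlineCleanup() {
    Executor sameThread = new Executor() {
      @Override
      public void execute(Runnable command) {
        command.run();  // no separate cleaner thread, deletion happens inline
      }
    };
    return new ExternalShuffleBlockManager(sameThread);
  }
}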
*/ + private static class AppExecId { + final String appId; + final String execId; + + private AppExecId(String appId, String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppExecId appExecId = (AppExecId) o; + return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .toString(); + } } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6bbabc44b958b..27884b82c8cb9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,14 +17,19 @@ package org.apache.spark.network.shuffle; -import java.io.Closeable; +import java.io.IOException; +import java.util.List; +import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import org.apache.spark.network.util.JavaUtils; @@ -36,30 +41,69 @@ * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. */ -public class ExternalShuffleClient implements ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); - private final TransportClientFactory clientFactory; - private final String appId; + private final TransportConf conf; + private final boolean saslEnabled; + private final SecretKeyHolder secretKeyHolder; - public ExternalShuffleClient(TransportConf conf, String appId) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); - this.clientFactory = context.createClientFactory(); + private TransportClientFactory clientFactory; + private String appId; + + /** + * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, + * then secretKeyHolder may be null. 
+ */ + public ExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + this.saslEnabled = saslEnabled; + } + + @Override + public void init(String appId) { this.appId = appId; + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + List bootstraps = Lists.newArrayList(); + if (saslEnabled) { + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + } + clientFactory = context.createClientFactory(bootstraps); } @Override public void fetchBlocks( - String host, - int port, - String execId, + final String host, + final int port, + final String execId, String[] blockIds, BlockFetchingListener listener) { + assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { - TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = + new RetryingBlockFetcher.BlockFetchStarter() { + @Override + public void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, blockIds, listener) + .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + } + }; + + int maxRetries = conf.maxIORetries(); + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start(); + } else { + blockFetchStarter.createAndStart(blockIds, listener); + } } catch (Exception e) { logger.error("Exception while beginning fetchBlocks", e); for (String blockId : blockIds) { @@ -81,7 +125,8 @@ public void registerWithShuffleServer( String host, int port, String execId, - ExecutorShuffleInfo executorInfo) { + ExecutorShuffleInfo executorInfo) throws IOException { + assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 39b6f30f92baf..9e77a1f68c4b0 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -51,9 +51,6 @@ public OneForOneBlockFetcher( TransportClient client, String[] blockIds, BlockFetchingListener listener) { - if (blockIds.length == 0) { - throw new IllegalArgumentException("Zero-sized blockIds array"); - } this.client = client; this.blockIds = blockIds; this.listener = listener; @@ -82,6 +79,10 @@ public void onFailure(int chunkIndex, Throwable e) { * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. 
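A sketch of the client API end to end: init() binds the appId and builds the SASL-enabled factory, after which fetchBlocks() drives the listener exactly once per block. The listener's success signature (blockId plus its ManagedBuffer) is assumed from the rest of this module; class name hypothetical:

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExternalShuffleClient;
import org.apache.spark.network.util.TransportConf;

class ExternalShuffleFetchExample {
  static void fetch(TransportConf conf, SecretKeyHolder keys, String host, int port) {
    ExternalShuffleClient client = new ExternalShuffleClient(conf, keys, true /* saslEnabled */);
    client.init("app-20141106-0001");  // must run before any fetch or registration

    client.fetchBlocks(host, port, "exec-1",
        new String[] { "shuffle_0_0_0", "shuffle_0_1_0" },
        new BlockFetchingListener() {
          @Override
          public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
            System.out.println("Fetched " + blockId);
          }

          @Override
          public void onBlockFetchFailure(String blockId, Throwable e) {
            System.err.println("Failed " + blockId + ": " + e);
          }
        });
  }
}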
*/ public void start(Object openBlocksMessage) { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { @@ -95,7 +96,7 @@ public void onSuccess(byte[] response) { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } } catch (Exception e) { - logger.error("Failed while starting block fetches", e); + logger.error("Failed while starting block fetches after success", e); failRemainingBlocks(blockIds, e); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java new file mode 100644 index 0000000000000..f8a1a266863bb --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to + * IOExceptions, which we hope are due to transient network conditions. + * + * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In + * particular, the listener will be invoked exactly once per blockId, with a success or failure. + */ +public class RetryingBlockFetcher { + + /** + * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any + * remaining blocks. + */ + public static interface BlockFetchStarter { + /** + * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous + * bootstrapping followed by fully asynchronous block fetching. + * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this + * method must throw an exception. + * + * This method should always attempt to get a new TransportClient from the + * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection + * issues. 
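+ *
+ * For illustration, a sketch that is not part of this change, mirroring the anonymous
+ * implementation used by ExternalShuffleClient in this patch ({@code conf},
+ * {@code clientFactory}, {@code host}, {@code port}, {@code appId}, {@code execId},
+ * {@code blockIds} and {@code listener} are assumed to be in scope):
+ * <pre>{@code
+ *   RetryingBlockFetcher.BlockFetchStarter starter =
+ *     new RetryingBlockFetcher.BlockFetchStarter() {
+ *       @Override
+ *       public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+ *           throws IOException {
+ *         TransportClient client = clientFactory.createClient(host, port);
+ *         new OneForOneBlockFetcher(client, blockIds, listener)
+ *           .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ *       }
+ *     };
+ *   new RetryingBlockFetcher(conf, starter, blockIds, listener).start();
+ * }</pre>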
+ */ + void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + } + + /** Shared executor service used for waiting and retrying. */ + private static final ExecutorService executorService = Executors.newCachedThreadPool( + NettyUtils.createThreadFactory("Block Fetch Retry")); + + private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class); + + /** Used to initiate new Block Fetches on our remaining blocks. */ + private final BlockFetchStarter fetchStarter; + + /** Parent listener which we delegate all successful or permanently failed block fetches to. */ + private final BlockFetchingListener listener; + + /** Max number of times we are allowed to retry. */ + private final int maxRetries; + + /** Milliseconds to wait before each retry. */ + private final int retryWaitTime; + + // NOTE: + // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated + // while inside a synchronized block. + /** Number of times we've attempted to retry so far. */ + private int retryCount = 0; + + /** + * Set of all block ids which have not been fetched successfully or with a non-IO Exception. + * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet, + * input ordering is preserved, so we always request blocks in the same order the user provided. + */ + private final LinkedHashSet outstandingBlocksIds; + + /** + * The BlockFetchingListener that is active with our current BlockFetcher. + * When we start a retry, we immediately replace this with a new Listener, which causes all any + * old Listeners to ignore all further responses. + */ + private RetryingBlockFetchListener currentListener; + + public RetryingBlockFetcher( + TransportConf conf, + BlockFetchStarter fetchStarter, + String[] blockIds, + BlockFetchingListener listener) { + this.fetchStarter = fetchStarter; + this.listener = listener; + this.maxRetries = conf.maxIORetries(); + this.retryWaitTime = conf.ioRetryWaitTime(); + this.outstandingBlocksIds = Sets.newLinkedHashSet(); + Collections.addAll(outstandingBlocksIds, blockIds); + this.currentListener = new RetryingBlockFetchListener(); + } + + /** + * Initiates the fetch of all blocks provided in the constructor, with possible retries in the + * event of transient IOExceptions. + */ + public void start() { + fetchAllOutstanding(); + } + + /** + * Fires off a request to fetch all blocks that have not been fetched successfully or permanently + * failed (i.e., by a non-IOException). + */ + private void fetchAllOutstanding() { + // Start by retrieving our shared state within a synchronized block. + String[] blockIdsToFetch; + int numRetries; + RetryingBlockFetchListener myListener; + synchronized (this) { + blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]); + numRetries = retryCount; + myListener = currentListener; + } + + // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails. + try { + fetchStarter.createAndStart(blockIdsToFetch, myListener); + } catch (Exception e) { + logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s", + blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); + + if (shouldRetry(e)) { + initiateRetry(); + } else { + for (String bid : blockIdsToFetch) { + listener.onBlockFetchFailure(bid, e); + } + } + } + } + + /** + * Lightweight method which initiates a retry in a different thread. 
The retry will involve + * calling fetchAllOutstanding() after a configured wait time. + */ + private synchronized void initiateRetry() { + retryCount += 1; + currentListener = new RetryingBlockFetchListener(); + + logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", + retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); + + executorService.submit(new Runnable() { + @Override + public void run() { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + fetchAllOutstanding(); + } + }); + } + + /** + * Returns true if we should retry due a block fetch failure. We will retry if and only if + * the exception was an IOException and we haven't retried 'maxRetries' times already. + */ + private synchronized boolean shouldRetry(Throwable e) { + boolean isIOException = e instanceof IOException + || (e.getCause() != null && e.getCause() instanceof IOException); + boolean hasRemainingRetries = retryCount < maxRetries; + return isIOException && hasRemainingRetries; + } + + /** + * Our RetryListener intercepts block fetch responses and forwards them to our parent listener. + * Note that in the event of a retry, we will immediately replace the 'currentListener' field, + * indicating that any responses from non-current Listeners should be ignored. + */ + private class RetryingBlockFetchListener implements BlockFetchingListener { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + // We will only forward this success message to our parent listener if this block request is + // outstanding and we are still the active listener. + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + outstandingBlocksIds.remove(blockId); + shouldForwardSuccess = true; + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardSuccess) { + listener.onBlockFetchSuccess(blockId, data); + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + // We will only forward this failure to our parent listener if this block request is + // outstanding, we are still the active listener, AND we cannot retry the fetch. + boolean shouldForwardFailure = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + if (shouldRetry(exception)) { + initiateRetry(); + } else { + logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)", + blockId, retryCount), exception); + outstandingBlocksIds.remove(blockId); + shouldForwardFailure = true; + } + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardFailure) { + listener.onBlockFetchFailure(blockId, exception); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index d46a562394557..f72ab40690d0d 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -20,7 +20,14 @@ import java.io.Closeable; /** Provides an interface for reading shuffle files, either from an Executor or external service. 
*/ -public interface ShuffleClient extends Closeable { +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ + public void init(String appId) { } + /** * Fetch a sequence of blocks from a remote node asynchronously, * @@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - public void fetchBlocks( + public abstract void fetchBlocks( String host, int port, String execId, diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000000000..d25283e46ef96 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. 
*/ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. 
*/ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000000000..67a07f38eb5a0 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. + */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java new 
file mode 100644 index 0000000000000..c8ece3bc53ac3 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ExternalShuffleCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + + @Test + public void noCleanupAndCleanup() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", false /* cleanup */); + + assertStillThere(dataContext); + + manager.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true /* cleanup */); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutor() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. 
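+ // (For clarity: the executor below only flips cleanupCalled and never runs the Runnable, so the
+ // shuffle files remain on disk until dataContext.cleanup() is invoked later in this test.)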
+ Executor noThreadExecutor = new Executor() { + @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } + }; + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + + dataContext.cleanup(); + assertCleanedUp(dataContext); + } + + @Test + public void cleanupMultipleExecutors() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRemovedApp() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + + manager.applicationRemoved("app-nonexistent", true); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-0", true); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + private void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); + } + } + + private TestShuffleDataContext createSomeData() throws IOException { + Random rand = new Random(123); + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + + dataContext.create(); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), + new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, + new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + return dataContext; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b3bcf5fd68e73..3bea5b0f253c6 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -105,7 +105,7 @@ public static void afterAll() { @After public void 
afterEach() { - handler.clearRegisteredExecutors(); + handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); } class FetchResult { @@ -135,7 +135,8 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @Override @@ -164,6 +165,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); } + client.close(); return res; } @@ -257,15 +259,22 @@ public void testFetchUnregisteredExecutor() throws Exception { @Test public void testFetchNoServer() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + System.setProperty("spark.shuffle.io.maxRetries", "0"); + try { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } finally { + System.clearProperty("spark.shuffle.io.maxRetries"); + } } - private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException { + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 0000000000000..848c88f743d50 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() { + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(), + new TestSecretKeyHolder("my-app-id", "secret")); + TransportContext context = new TransportContext(conf, handler); + this.server = context.createServer(); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() throws IOException { + validate("my-app-id", "secret"); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey) throws IOException { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java new file mode 100644 index 0000000000000..0191fe529e1be --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedHashSet; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; + +/** + * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to + * fetch the lost blocks. + */ +public class RetryingBlockFetcherSuite { + + ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + + @Before + public void beforeEach() { + System.setProperty("spark.shuffle.io.maxRetries", "2"); + System.setProperty("spark.shuffle.io.retryWaitMs", "0"); + } + + @After + public void afterEach() { + System.clearProperty("spark.shuffle.io.maxRetries"); + System.clearProperty("spark.shuffle.io.retryWaitMs"); + } + + @Test + public void testNoFailures() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // Immediately return both blocks successfully. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchSuccess("b0", block0); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testUnrecoverableFailure() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0 throws a non-IOException error, so it will be failed without retry. + ImmutableMap.builder() + .put("b0", new RuntimeException("Ouch!")) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnFirst() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b0 fails, we will retry both. 
+ ImmutableMap.builder() + .put("b0", new IOException("Connection failed or something")) + .put("b1", block1) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnSecond() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b1 fails, we will not retry b0. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException("Connection failed or something")) + .build(), + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testTwoIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 returns successfully within 2 retries. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testThreeIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 errors again, but this was the last retry + ImmutableMap.builder() + .put("b1", new IOException()) + .build(), + // This is not reached -- b1 has failed. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verifyNoMoreInteractions(listener); + } + + @Test + public void testRetryAndUnrecoverable() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, subsequent messages will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new RuntimeException()) + .put("b2", block2) + .build(), + // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. 
+ ImmutableMap.builder() + .put("b0", block0) + .put("b1", new RuntimeException()) + .put("b2", new IOException()) + .build(), + // b2 succeeds in its last retry. + ImmutableMap.builder() + .put("b2", block2) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verifyNoMoreInteractions(listener); + } + + /** + * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. + * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction + * means "respond to the next block fetch request with these Successful buffers and these Failure + * exceptions". We verify that the expected block ids are exactly the ones requested. + * + * If multiple interactions are supplied, they will be used in order. This is useful for encoding + * retries -- the first interaction may include an IOException, which causes a retry of some + * subset of the original blocks in a second interaction. + */ + @SuppressWarnings("unchecked") + private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + throws IOException { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); + + Stubber stub = null; + + // Contains all blockIds that are referenced across all interactions. + final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + + for (final Map interaction : interactions) { + blockIds.addAll(interaction.keySet()); + + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. + BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); + } + } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; + } + } + }; + + // This is either the first stub, or should be chained behind the prior ones. 
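+ // (For clarity: chaining doAnswer() makes Mockito play the answers back in order, so the n-th
+ // call to createAndStart() receives the n-th interaction map.)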
+ if (stub == null) { + stub = doAnswer(answer); + } else { + stub.doAnswer(answer); + } + } + + assert stub != null; + stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); + new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 442b756467442..337b5c7bdb5da 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -30,8 +30,8 @@ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. */ public class TestShuffleDataContext { - private final String[] localDirs; - private final int subDirsPerLocalDir; + public final String[] localDirs; + public final int subDirsPerLocalDir; public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { this.localDirs = new String[numLocalDirs]; diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml new file mode 100644 index 0000000000000..e60d8c1f7876c --- /dev/null +++ b/network/yarn/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-yarn_2.10 + jar + Spark Project Yarn Shuffle Service Code + http://spark.apache.org/ + + network-yarn + + + + + + org.apache.spark + spark-network-shuffle_2.10 + ${project.version} + + + + + org.apache.hadoop + hadoop-client + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java new file mode 100644 index 0000000000000..bb0b8f7e6cba6 --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.yarn; + +import java.lang.Override; +import java.nio.ByteBuffer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; +import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.ShuffleSecretManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.yarn.util.HadoopConfigProvider; + +/** + * An external shuffle service used by Spark on Yarn. + * + * This is intended to be a long-running auxiliary service that runs in the NodeManager process. + * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`. + * The application also automatically derives the service port through `spark.shuffle.service.port` + * specified in the Yarn configuration. This is so that both the clients and the server agree on + * the same port to communicate on. + * + * The service also optionally supports authentication. This ensures that executors from one + * application cannot read the shuffle files written by those from another. This feature can be + * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM. + * Note that the Spark application must also set `spark.authenticate` manually and, unlike in + * the case of the service port, will not inherit this setting from the Yarn configuration. This + * is because an application running on the same Yarn cluster may choose to not use the external + * shuffle service, in which case its setting of `spark.authenticate` should be independent of + * the service's. + */ +public class YarnShuffleService extends AuxiliaryService { + private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class); + + // Port on which the shuffle server listens for fetch requests + private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"; + private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337; + + // Whether the shuffle server should authenticate fetch requests + private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; + private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + + // An entity that manages the shuffle secret per application + // This is used only if authentication is enabled + private ShuffleSecretManager secretManager; + + // The actual server that serves shuffle files + private TransportServer shuffleServer = null; + + public YarnShuffleService() { + super("spark_shuffle"); + logger.info("Initializing YARN shuffle service for Spark"); + } + + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. 
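+ *
+ * For illustration, not part of this change: per the class comment above, this flag is driven
+ * by the {@code spark.authenticate} key read from the NodeManager's Hadoop configuration
+ * (typically yarn-site.xml), e.g.
+ * <pre>{@code
+ *   <property>
+ *     <name>spark.authenticate</name>
+ *     <value>true</value>
+ *   </property>
+ * }</pre>
+ * The Spark application must set {@code spark.authenticate} independently; it is not inherited
+ * from the Yarn configuration.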
+ */ + private boolean isAuthenticationEnabled() { + return secretManager != null; + } + + /** + * Start the shuffle server with the given configuration. + */ + @Override + protected void serviceInit(Configuration conf) { + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(); + if (authEnabled) { + secretManager = new ShuffleSecretManager(); + rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportContext transportContext = new TransportContext(transportConf, rpcHandler); + shuffleServer = transportContext.createServer(port); + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. " + + "Authentication is {}.", port, authEnabledString); + } + + @Override + public void initializeApplication(ApplicationInitializationContext context) { + String appId = context.getApplicationId().toString(); + try { + ByteBuffer shuffleSecret = context.getApplicationDataForService(); + logger.info("Initializing application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.registerApp(appId, shuffleSecret); + } + } catch (Exception e) { + logger.error("Exception when initializing application {}", appId, e); + } + } + + @Override + public void stopApplication(ApplicationTerminationContext context) { + String appId = context.getApplicationId().toString(); + try { + logger.info("Stopping application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.unregisterApp(appId); + } + } catch (Exception e) { + logger.error("Exception when stopping application {}", appId, e); + } + } + + @Override + public void initializeContainer(ContainerInitializationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Initializing container {}", containerId); + } + + @Override + public void stopContainer(ContainerTerminationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Stopping container {}", containerId); + } + + /** + * Close the shuffle server to clean up any associated state. + */ + @Override + protected void serviceStop() { + try { + if (shuffleServer != null) { + shuffleServer.close(); + } + } catch (Exception e) { + logger.error("Exception when stopping service", e); + } + } + + // Not currently used + @Override + public ByteBuffer getMetaData() { + return ByteBuffer.allocate(0); + } + +} diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java new file mode 100644 index 0000000000000..884861752e80d --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn.util; + +import java.util.NoSuchElementException; + +import org.apache.hadoop.conf.Configuration; + +import org.apache.spark.network.util.ConfigProvider; + +/** Use the Hadoop configuration to obtain config values. */ +public class HadoopConfigProvider extends ConfigProvider { + private final Configuration conf; + + public HadoopConfigProvider(Configuration conf) { + this.conf = conf; + } + + @Override + public String get(String name) { + String value = conf.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/pom.xml b/pom.xml index 6191cd3a541e2..88ef67c515b3a 100644 --- a/pom.xml +++ b/pom.xml @@ -1229,6 +1229,7 @@ yarn-alpha yarn + network/yarn @@ -1236,6 +1237,7 @@ yarn yarn + network/yarn @@ -1359,7 +1361,7 @@ false - 0.13.1 + 0.13.1a 0.13.1 10.10.1.1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 33618f5401768..657e4b4432775 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -38,9 +38,9 @@ object BuildCommons { "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") - .map(ProjectRef(buildLocation, _)) + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, networkYarn, java8Tests, + sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "network-yarn", + "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) @@ -143,7 +143,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach { + streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5f8dcedb1eea2..faa5952258aef 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer, AutoBatchedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -63,7 +63,6 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() 
_python_includes = None # zip and egg files that need to be added to PYTHONPATH - _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, @@ -115,9 +114,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer - if batchSize == 1: - self.serializer = self._unbatched_serializer - elif batchSize == 0: + if batchSize == 0: self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, @@ -305,12 +302,8 @@ def parallelize(self, c, numSlices=None): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self._batchSize) - if batchSize > 1: - serializer = BatchedSerializer(self._unbatched_serializer, - batchSize) - else: - serializer = self._unbatched_serializer + batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile @@ -328,8 +321,7 @@ def pickleFile(self, name, minPartitions=None): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ minPartitions = minPartitions or self.defaultMinPartitions - return RDD(self._jsc.objectFile(name, minPartitions), self, - BatchedSerializer(PickleSerializer())) + return RDD(self._jsc.objectFile(name, minPartitions), self) def textFile(self, name, minPartitions=None, use_unicode=True): """ @@ -396,6 +388,36 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) + def binaryFiles(self, path, minPartitions=None): + """ + :: Experimental :: + + Read a directory of binary files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system URI + as a byte array. Each file is read as a single record and returned + in a key-value pair, where the key is the path of each file, the + value is the content of each file. + + Note: Small files are preferred, large file is also allowable, but + may cause bad performance. + """ + minPartitions = minPartitions or self.defaultMinPartitions + return RDD(self._jsc.binaryFiles(path, minPartitions), self, + PairDeserializer(UTF8Deserializer(), NoOpSerializer())) + + def binaryRecords(self, path, recordLength): + """ + :: Experimental :: + + Load data from a flat binary file, assuming each record is a set of numbers + with the specified numerical format (see ByteBuffer), and the number of + bytes per record is constant. 
+ + :param path: Directory to the input data files + :param recordLength: The length at which to split the records + """ + return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer()) + def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: @@ -405,7 +427,7 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None, batchSize=None): + valueConverter=None, minSplits=None, batchSize=0): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -427,17 +449,15 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ minSplits = minSplits or min(self.defaultParallelism, 2) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, keyConverter, valueConverter, minSplits, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -458,18 +478,16 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -487,18 +505,16 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. 
(default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -519,18 +535,16 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -548,15 +562,13 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. 
(default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) @@ -836,7 +848,7 @@ def _test(): import doctest import tempfile globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 76864d8163586..c6149fe391ec8 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -96,10 +96,15 @@ def _java2py(sc, r): if clsName == 'JavaRDD': jrdd = sc._jvm.SerDe.javaToPython(r) - return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer())) + return RDD(jrdd, sc) - elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes: + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.SerDe.dumps(r) + except Py4JJavaError: + pass # not pickable if isinstance(r, bytearray): r = PickleSerializer().loads(str(r)) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index d0a0e102a1a07..e35202dca0acc 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,8 +29,11 @@ import numpy as np +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, Row -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] + +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): @@ -106,7 +109,54 @@ def _format_float(f, digits=4): return s +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. 
+ """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ @@ -528,6 +578,8 @@ class DenseMatrix(Matrix): def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) assert len(values) == numRows * numCols + if not isinstance(values, array.array): + values = array.array('d', values) self.values = values def __reduce__(self): @@ -546,6 +598,15 @@ def toArray(self): return np.reshape(self.values, (self.numRows, self.numCols), order='F') +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) + + def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 6b32af07c9be2..e8b998414d319 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -117,7 +117,7 @@ def _test(): import doctest import pyspark.mllib.recommendation globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 15f0652f833d7..0700f8a8e5a8e 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,11 +19,12 @@ Python package for statistical functions in MLlib. """ +from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import Matrix, _convert_to_vector -__all__ = ['MultivariateStatisticalSummary', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -51,6 +52,54 @@ def min(self): return self.call("min").toArray() +class ChiSqTestResult(JavaModelWrapper): + """ + :: Experimental :: + + Object containing the test results for the chi-squared hypothesis test. 
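The VectorUDT added to pyspark.mllib.linalg above stores vectors as a plain SQL struct (a type tag plus size/indices/values), and Matrices.dense is a thin constructor for DenseMatrix. A minimal round-trip sketch based on the serialize/deserialize methods shown above:

    from pyspark.mllib.linalg import SparseVector, Vectors, Matrices, VectorUDT

    udt = VectorUDT()
    dv = Vectors.dense([1.0, 2.0])
    sv = SparseVector(2, [1], [3.0])

    udt.serialize(dv)                                  # (1, None, None, [1.0, 2.0]) -- dense layout
    udt.serialize(sv)                                  # (0, 2, [1], [3.0])          -- sparse layout
    assert udt.deserialize(udt.serialize(sv)) == sv    # round trip, as exercised by the new tests

    m = Matrices.dense(2, 2, [1.0, 3.0, 2.0, 4.0])     # values are column-major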
+ """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() + + class Statistics(object): @staticmethod @@ -135,6 +184,90 @@ def corr(x, y=None, method=None): else: return callMLlibFunc("corr", x.map(float), y.map(float), method) + @staticmethod + def chiSqTest(observed, expected=None): + """ + :: Experimental :: + + If `observed` is Vector, conduct Pearson's chi-squared goodness + of fit test of the observed data against the expected distribution, + or againt the uniform distribution (by default), with each category + having an expected frequency of `1 / len(observed)`. + (Note: `observed` cannot contain negative values) + + If `observed` is matrix, conduct Pearson's independence test on the + input contingency matrix, which cannot contain negative entries or + columns or rows that sum up to 0. + + If `observed` is an RDD of LabeledPoint, conduct Pearson's independence + test for every feature against the label across the input RDD. + For each feature, the (feature, label) pairs are converted into a + contingency matrix for which the chi-squared statistic is computed. + All label and feature values must be categorical. + + :param observed: it could be a vector containing the observed categorical + counts/relative frequencies, or the contingency matrix + (containing either counts or relative frequencies), + or an RDD of LabeledPoint containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + :param expected: Vector containing the expected categorical counts/relative + frequencies. `expected` is rescaled if the `expected` sum + differs from the `observed` sum. + :return: ChiSquaredTest object containing the test statistic, degrees + of freedom, p-value, the method used, and the null hypothesis. + + >>> from pyspark.mllib.linalg import Vectors, Matrices + >>> observed = Vectors.dense([4, 6, 5]) + >>> pearson = Statistics.chiSqTest(observed) + >>> print pearson.statistic + 0.4 + >>> pearson.degreesOfFreedom + 2 + >>> print round(pearson.pValue, 4) + 0.8187 + >>> pearson.method + u'pearson' + >>> pearson.nullHypothesis + u'observed follows the same distribution as expected.' + + >>> observed = Vectors.dense([21, 38, 43, 80]) + >>> expected = Vectors.dense([3, 5, 7, 20]) + >>> pearson = Statistics.chiSqTest(observed, expected) + >>> print round(pearson.pValue, 4) + 0.0027 + + >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + >>> print round(chi.statistic, 4) + 21.9958 + + >>> from pyspark.mllib.regression import LabeledPoint + >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + ... 
LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] + >>> rdd = sc.parallelize(data, 4) + >>> chi = Statistics.chiSqTest(rdd) + >>> print chi[0].statistic + 0.75 + >>> print chi[1].statistic + 1.5 + """ + if isinstance(observed, RDD): + jmodels = callMLlibFunc("chiSqTest", observed) + return [ChiSqTestResult(m) for m in jmodels] + + if isinstance(observed, Matrix): + jmodel = callMLlibFunc("chiSqTest", observed) + else: + if expected and len(expected) != len(observed): + raise ValueError("`expected` should have same length with `observed`") + jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected) + return ChiSqTestResult(jmodel) + def _test(): import doctest diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index d6fb87b378b4a..9fa4d6f6a2f5f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,14 +33,14 @@ else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase - _have_scipy = False try: import scipy.sparse @@ -221,6 +221,39 @@ def test_col_with_different_rdds(self): self.assertEqual(10, summary.count()) +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 550c9dd80522f..879655dc53f4a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -120,7 +120,7 @@ class RDD(object): operated on in parallel. 
""" - def __init__(self, jrdd, ctx, jrdd_deserializer): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -129,12 +129,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._id = jrdd.id() self._partitionFunc = None - def _toPickleSerialization(self): - if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): - return self - else: - return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def _pickled(self): + return self._reserialize(AutoBatchedSerializer(PickleSerializer())) def id(self): """ @@ -316,9 +312,6 @@ def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD (relies on numpy and falls back on default random generator if numpy is unavailable). - - >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP - [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) @@ -449,12 +442,11 @@ def intersection(self, other): def _reserialize(self, serializer=None): serializer = serializer or self.ctx.serializer - if self._jrdd_deserializer == serializer: - return self - else: - converted = self.map(lambda x: x, preservesPartitioning=True) - converted._jrdd_deserializer = serializer - return converted + if self._jrdd_deserializer != serializer: + if not isinstance(self, PipelinedRDD): + self = self.map(lambda x: x, preservesPartitioning=True) + self._jrdd_deserializer = serializer + return self def __add__(self, other): """ @@ -1123,9 +1115,8 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True) def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1150,9 +1141,8 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) @@ -1169,9 +1159,8 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False) def saveAsHadoopFile(self, path, 
outputFormatClass, keyClass=None, valueClass=None, @@ -1198,9 +1187,8 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, @@ -1218,9 +1206,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): :param path: path to sequence file :param compressionCodecClass: (None by default) """ - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True, path, compressionCodecClass) def saveAsPickleFile(self, path, batchSize=10): @@ -1235,8 +1222,11 @@ def saveAsPickleFile(self, path, batchSize=10): >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) [1, 2, 'rdd', 'spark'] """ - self._reserialize(BatchedSerializer(PickleSerializer(), - batchSize))._jrdd.saveAsObjectFile(path) + if batchSize == 0: + ser = AutoBatchedSerializer(PickleSerializer()) + else: + ser = BatchedSerializer(PickleSerializer(), batchSize) + self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path): """ @@ -1777,13 +1767,10 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ - if self.getNumPartitions() != other.getNumPartitions(): - raise ValueError("Can only zip with RDD which has the same number of partitions") - def get_batch_size(ser): if isinstance(ser, BatchedSerializer): return ser.batchSize - return 0 + return 1 def batch_as(rdd, batchSize): ser = rdd._jrdd_deserializer @@ -1793,12 +1780,16 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: - # use the greatest batchSize to batch the other one. - if my_batch > other_batch: - other = batch_as(other, my_batch) - else: - self = batch_as(self, other_batch) + # use the smallest batchSize for both of them + batchSize = min(my_batch, other_batch) + if batchSize <= 0: + # auto batched or unlimited + batchSize = 100 + other = batch_as(other, batchSize) + self = batch_as(self, batchSize) + + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") # There will be an Exception in JVM if there are different number # of items in each partitions. @@ -1937,25 +1928,14 @@ def lookup(self, key): return values.collect() - def _is_pickled(self): - """ Return this RDD is serialized by Pickle or not. """ - der = self._jrdd_deserializer - if isinstance(der, PickleSerializer): - return True - if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer): - return True - return False - def _to_java_object_rdd(self): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. 
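One behavioral consequence of the zip() rework above: the two RDDs no longer need matching batch sizes up front. zip() rebatches both sides to the smaller explicit batch size, falling back to a fixed size of 100 when either side is auto-batched or unbatched, and only then checks the partition counts. A rough sketch, assuming a live SparkContext sc:

    a = sc.parallelize(range(1000), 2)   # parallelize picks a fixed batch size
    b = a.map(str)                       # pipelined RDD uses the context's auto-batched serializer
    assert a.zip(b).count() == 1000      # both sides are rebatched before pairing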
""" - rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \ - if not self._is_pickled() else self - is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch) + rdd = self._pickled() + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) def countApprox(self, timeout, confidence=0.95): """ @@ -2135,7 +2115,7 @@ def _test(): globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 528a181e8905a..f5c3cfd259a5b 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None): def initRandomGenerator(self, split): if self._use_numpy: import numpy - self._random = numpy.random.RandomState(self._seed) + self._random = numpy.random.RandomState(self._seed ^ split) else: - self._random = random.Random(self._seed) + self._random = random.Random(self._seed ^ split) - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - self._random.randint(0, 2 ** 32 - 1) + # mixing because the initial seeds are close to each other + for _ in xrange(10): + self._random.randint(0, 1) self._split = split self._rand_initialized = True diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 904bd9f2652d3..d597cbf94e1b1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,9 +33,8 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -By default, PySpark serialize objects in batches; the batch size can be -controlled through SparkContext's C{batchSize} parameter -(the default size is 1024 objects): +PySpark serialize objects in batches; By default, the batch size is chosen based +on the size of objects, also configurable by SparkContext's C{batchSize} parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -48,16 +47,6 @@ >>> rdd._jrdd.count() 8L >>> sc.stop() - -A batch size of -1 uses an unlimited batch size, and a size of 1 disables -batching: - ->>> sc = SparkContext('local', 'test', batchSize=1) ->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) ->>> rdd.glom().collect() -[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -16L """ import cPickle @@ -73,7 +62,7 @@ from pyspark import cloudpickle -__all__ = ["PickleSerializer", "MarshalSerializer"] +__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] class SpecialLengths(object): @@ -113,7 +102,7 @@ def __ne__(self, other): return not self.__eq__(other) def __repr__(self): - return "<%s object>" % self.__class__.__name__ + return "%s()" % self.__class__.__name__ def __hash__(self): return hash(str(self)) @@ -181,6 +170,7 @@ class BatchedSerializer(Serializer): """ UNLIMITED_BATCH_SIZE = -1 + UNKNOWN_BATCH_SIZE = 0 def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): self.serializer = serializer @@ -213,10 +203,10 @@ def _load_stream_without_unbatching(self, stream): def __eq__(self, other): return (isinstance(other, BatchedSerializer) and - other.serializer 
== self.serializer) + other.serializer == self.serializer and other.batchSize == self.batchSize) def __repr__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) class AutoBatchedSerializer(BatchedSerializer): @@ -225,7 +215,7 @@ class AutoBatchedSerializer(BatchedSerializer): """ def __init__(self, serializer, bestSize=1 << 16): - BatchedSerializer.__init__(self, serializer, -1) + BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE) self.bestSize = bestSize def dump_stream(self, iterator, stream): @@ -248,10 +238,10 @@ def dump_stream(self, iterator, stream): def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.bestSize == self.bestSize) def __str__(self): - return "AutoBatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % str(self.serializer) class CartesianDeserializer(FramedSerializer): @@ -284,7 +274,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "CartesianDeserializer<%s, %s>" % \ + return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -311,7 +301,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) + return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): @@ -430,7 +420,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol autumatically + Choose marshal or cPickle as serialization protocol automatically """ def __init__(self): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index d57a802e4734a..5931e923c2e36 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -25,7 +25,7 @@ import random import pyspark.heapq3 as heapq -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer try: import psutil @@ -213,8 +213,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests - self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -470,7 +469,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) def _get_path(self, n): """ Choose one directory for spill by number n """ diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 98e41f8575679..e5d62a466cab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -44,7 +44,8 @@ from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from 
pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -109,6 +110,15 @@ def __eq__(self, other): return self is other +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. + """ + + class StringType(PrimitiveType): """Spark SQL StringType @@ -331,7 +341,7 @@ class StructField(DataType): """ - def __init__(self, name, dataType, nullable, metadata=None): + def __init__(self, name, dataType, nullable=True, metadata=None): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. @@ -408,6 +418,75 @@ def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) +class UserDefinedType(DataType): + """ + :: WARN: Spark Internal Use Only :: + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + _all_primitive_types = dict((v.typeName(), v) for v in globals().itervalues() if type(v) is PrimitiveTypeSingleton and @@ -460,6 +539,12 @@ def _parse_datatype_json_string(json_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ return _parse_datatype_json_value(json.loads(json_string)) @@ -479,11 +564,18 @@ def _parse_datatype_json_value(json_value): else: raise ValueError("Could not parse datatype: %s" % json_value) else: - return _all_complex_types[json_value["type"]].fromJson(json_value) + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) # Mapping Python types to Spark SQL DataType _type_mappings = { + type(None): NullType, bool: BooleanType, int: IntegerType, long: LongType, @@ -499,23 +591,34 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ if obj is None: raise ValueError("Can not infer type for None") + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() if isinstance(obj, dict): - if not obj: - raise ValueError("Can not infer type for empty dict") - key, value = obj.iteritems().next() - return MapType(_infer_type(key), _infer_type(value), True) + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) elif isinstance(obj, (list, array)): - if not obj: - raise ValueError("Can not infer type for empty list/array") - return ArrayType(_infer_type(obj[0]), True) + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) else: try: return _infer_schema(obj) @@ -548,60 +651,180 @@ def _infer_schema(row): return StructType(fields) -def _create_converter(obj, dataType): +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. 
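With NullType available, _infer_type above no longer fails on empty collections; it records NullType placeholders that inferSchema later resolves by merging schemas from additional rows. An illustrative sketch (these helpers are module-private, so the import is for demonstration only):

    from pyspark.sql import _infer_type

    _infer_type([])            # ArrayType(NullType(), True)
    _infer_type({None: None})  # MapType(NullType(), NullType(), True)
    _infer_type([1, 2])        # ArrayType(IntegerType(), True)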
+ + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ if isinstance(dataType, ArrayType): - conv = _create_converter(obj[0], dataType.elementType) + conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) elif isinstance(dataType, MapType): - value = obj.values()[0] - conv = _create_converter(value, dataType.valueType) + conv = 
_create_converter(dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif isinstance(dataType, NullType): + return lambda x: None + elif not isinstance(dataType, StructType): return lambda x: x # dataType must be StructType names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "fields"): + d = dict(zip(obj.fields, obj)) + if hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % obj) - if isinstance(obj, dict): - conv = lambda o: tuple(o.get(n) for n in names) - - elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple - conv = tuple - elif hasattr(obj, "__FIELDS__"): - conv = tuple - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - conv = lambda o: tuple(v for k, v in o) + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ else: - raise ValueError("unexpected tuple") - - elif hasattr(obj, "__dict__"): # object - conv = lambda o: [o.__dict__.get(n, None) for n in names] - - if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): - return conv - - row = conv(obj) - convs = [_create_converter(v, f.dataType) - for v, f in zip(row, dataType.fields)] - - def nested_conv(row): - return tuple(f(v) for f, v in zip(convs, conv(row))) + raise ValueError("Unexpected obj: %s" % obj) - return nested_conv + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - -def _drop_schema(rows, schema): - """ all the names of fields, becoming tuples""" - iterator = iter(rows) - row = iterator.next() - converter = _create_converter(row, schema) - yield converter(row) - for i in iterator: - yield converter(i) + return convert_struct _BRACKETS = {'(': ')', '[': ']', '{': '}'} @@ -713,7 +936,7 @@ def _infer_schema_type(obj, dataType): return _infer_type(obj) if not obj: - raise ValueError("Can not infer type from empty value") + return NullType() if isinstance(dataType, ArrayType): eType = _infer_schema_type(obj[0], dataType.elementType) @@ -775,11 +998,22 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
""" # all objects are nullable if obj is None: return + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType @@ -854,6 +1088,8 @@ def _has_struct_or_date(dt): return _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -924,6 +1160,9 @@ def Dict(d): elif isinstance(dataType, DateType): return datetime.date + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -995,7 +1234,6 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray self._scala_SQLContext = sqlContext @property @@ -1025,8 +1263,8 @@ def registerFunction(self, name, f, returnType=StringType()): """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, None, - BatchedSerializer(PickleSerializer(), 1024), - BatchedSerializer(PickleSerializer(), 1024)) + AutoBatchedSerializer(PickleSerializer()), + AutoBatchedSerializer(PickleSerializer())) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) if len(pickled_command) > (1 << 20): # 1M @@ -1049,18 +1287,20 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) - def inferSchema(self, rdd): + def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. - We peek at the first row of the RDD to determine the fields' names - and types. Nested collections are supported, which include array, - dict, list, Row, tuple, namedtuple, or object. + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. - All the rows in `rdd` should have the same type with the first one, - or it will cause runtime exceptions. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. - Each row could be L{pyspark.sql.Row} object or namedtuple or objects, - using dict is deprecated. + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. >>> rdd = sc.parallelize( ... 
[Row(field1=1, field2="row1"), @@ -1097,8 +1337,23 @@ def inferSchema(self, rdd): warnings.warn("Using RDD of dict to inferSchema is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) - rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): @@ -1184,8 +1439,11 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) - batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - jrdd = self._pythonToJava(rdd._jrdd, batched) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1219,7 +1477,7 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path, schema=None): + def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a L{SchemaRDD}. @@ -1227,8 +1485,8 @@ def jsonFile(self, path, schema=None): If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -1274,20 +1532,20 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) - def jsonRDD(self, rdd, schema=None): + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. 
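To make the revised inference path concrete, here is a sketch in the spirit of the new test_infer_schema unit test below; it assumes sqlCtx and sc as in the doctests. The first row leaves l and d as NullType placeholders, so inferSchema keeps scanning (up to 100 rows) or, with a samplingRatio, merges schemas across the sampled rows:

    from pyspark.sql import Row
    rows = [Row(l=[], d={}),
            Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
    rdd = sc.parallelize(rows)
    srdd = sqlCtx.inferSchema(rdd)        # NullType fields resolved from the second row
    srdd2 = sqlCtx.inferSchema(rdd, 1.0)  # schema merged over the (sampled) dataset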
>>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") @@ -1344,7 +1602,7 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) @@ -1582,7 +1840,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_checkpointed = False self.ctx = self.sql_ctx._sc # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) + self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -1812,15 +2070,13 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest - from array import array from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext + from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - sc = SparkContext('local[4]', 'PythonTest', batchSize=2) + sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( @@ -1828,6 +2084,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 37a128907b3a7..9f625c5c6ca48 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,8 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -241,7 +242,7 @@ class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name, batchSize=2) + self.sc = SparkContext('local[4]', class_name) def tearDown(self): self.sc.stop() @@ -252,7 +253,7 @@ class ReusedPySparkTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + cls.sc = SparkContext('local[4]', cls.__name__) @classmethod def tearDownClass(cls): @@ -648,6 +649,21 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + 
self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + class ProfilerTests(PySparkTestCase): @@ -655,7 +671,7 @@ def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): @@ -679,8 +695,65 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + class SQLTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -781,6 +854,25 @@ def test_serialize_nested_array_and_map(self): self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], srdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) + srdd.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + + srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(srdd.schema(), srdd2.schema()) + self.assertEqual({}, srdd2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) + srdd2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) @@ -790,6 +882,39 @@ def test_convert_row_to_dict(self): row = self.sqlCtx.sql("select l[0].a AS la from test").first() self.assertEqual(1, row.asDict()["la"]) + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) 
+ srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + class InputFormatTests(ReusedPySparkTestCase): @@ -887,16 +1012,19 @@ def test_sequencefiles(self): clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) - ec = (u'1', - {u'__class__': u'org.apache.spark.api.python.TestWritable', - u'double': 54.0, u'int': 123, u'str': u'test1'}) - self.assertEqual(clazz[0], ec) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) - self.assertEqual(unbatched_clazz[0], ec) + ).collect()) + self.assertEqual(unbatched_clazz, ec) def test_oldhadoop(self): basepath = self.tempdir.name @@ -982,6 +1110,25 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = "short binary data" + with open(os.path.join(path, "part-0000"), 'w') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(range(100), result) + class OutputFormatTests(ReusedPySparkTestCase): @@ -1216,51 +1363,6 @@ def test_reserialization(self): result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) - def test_unbatched_save_and_read(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei, len(ei)).saveAsSequenceFile( - basepath + "/unbatched/") - - unbatched_sequence = sorted(self.sc.sequenceFile( - basepath + "/unbatched/", - batchSize=1).collect()) - 
self.assertEqual(unbatched_sequence, ei) - - unbatched_hadoopFile = sorted(self.sc.hadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopFile, ei) - - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopFile, ei) - - oldconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=oldconf, - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopRDD, ei) - - newconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=newconf, - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopRDD, ei) - def test_malformed_RDD(self): basepath = self.tempdir.name # non-batch-serialized RDD[[(K, V)]] should be rejected diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 00fc4d75c9ea9..5e613e0f18ba6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -242,6 +242,7 @@ class SqlParser extends AbstractSparkSQLParser { | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { case e1 ~ e2 => In(e1, e2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 2059a91ba0612..0415d74bd8141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -28,6 +28,8 @@ trait Catalog { def caseSensitive: Boolean + def tableExists(db: Option[String], tableName: String): Boolean + def lookupRelation( databaseName: Option[String], tableName: String, @@ -82,6 +84,14 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables.clear() } + override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + tables.get(tblName) match { + case Some(_) => true + case None => false + } + } + override def lookupRelation( databaseName: Option[String], tableName: String, @@ -107,6 +117,14 @@ trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... 
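// Editorial aside, not part of the patch: the tableExists hook added to Catalog above is a
// plain membership check, so the SimpleCatalog variant shown earlier could equivalently be
// written without matching on the Option (a suggested fragment, reusing this file's names):
override def tableExists(db: Option[String], tableName: String): Boolean = {
  val (_, tblName) = processDatabaseAndTableName(db, tableName)
  tables.contains(tblName)
}
// Likewise, the NOT LIKE production added to SqlParser above means a predicate such as
// "name NOT LIKE 'a%'" now parses to Not(Like('name, "a%")) instead of failing to parse.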
val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + abstract override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + overrides.get((dbName, tblName)) match { + case Some(_) => true + case None => super.tableExists(db, tableName) + } + } + abstract override def lookupRelation( databaseName: Option[String], tableName: String, @@ -149,6 +167,10 @@ object EmptyCatalog extends Catalog { val caseSensitive: Boolean = true + def tableExists(db: Option[String], tableName: String): Boolean = { + throw new UnsupportedOperationException + } + def lookupRelation( databaseName: Option[String], tableName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7e6d770314f5a..84fb594da372d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ @@ -109,7 +110,8 @@ package object dsl { def asc = SortOrder(expr, Ascending) def desc = SortOrder(expr, Descending) - def as(s: Symbol) = Alias(expr, s.name)() + def as(alias: String) = Alias(expr, alias)() + def as(alias: Symbol) = Alias(expr, alias.name)() } trait ExpressionConversions { @@ -146,6 +148,44 @@ package object dsl { def upper(e: Expression) = Upper(e) def lower(e: Expression) = Lower(e) + /* + * Conversions to provide the standard operators in the special case + * where a literal is being combined with a symbol. Without these an + * expression such as 0 < 'x is not recognized. 
+ */ + class LhsLiteral(x: Any) { + val literal = Literal(x) + def + (other: Symbol) = Add(literal, other) + def - (other: Symbol) = Subtract(literal, other) + def * (other: Symbol) = Multiply(literal, other) + def / (other: Symbol) = Divide(literal, other) + def % (other: Symbol) = Remainder(literal, other) + + def && (other: Symbol) = And(literal, other) + def || (other: Symbol) = Or(literal, other) + + def < (other: Symbol) = LessThan(literal, other) + def <= (other: Symbol) = LessThanOrEqual(literal, other) + def > (other: Symbol) = GreaterThan(literal, other) + def >= (other: Symbol) = GreaterThanOrEqual(literal, other) + def === (other: Symbol) = EqualTo(literal, other) + def <=> (other: Symbol) = EqualNullSafe(literal, other) + def !== (other: Symbol) = Not(EqualTo(literal, other)) + } + + implicit def booleanToLhsLiteral(b: Boolean) = new LhsLiteral(b) + implicit def byteToLhsLiteral(b: Byte) = new LhsLiteral(b) + implicit def shortToLhsLiteral(s: Short) = new LhsLiteral(s) + implicit def intToLhsLiteral(i: Int) = new LhsLiteral(i) + implicit def longToLhsLiteral(l: Long) = new LhsLiteral(l) + implicit def floatToLhsLiteral(f: Float) = new LhsLiteral(f) + implicit def doubleToLhsLiteral(d: Double) = new LhsLiteral(d) + implicit def stringToLhsLiteral(s: String) = new LhsLiteral(s) + implicit def bigDecimalToLhsLiteral(d: BigDecimal) = new LhsLiteral(d) + implicit def decimalToLhsLiteral(d: Decimal) = new LhsLiteral(d) + implicit def dateToLhsLiteral(d: Date) = new LhsLiteral(d) + implicit def timestampToLhsLiteral(t: Timestamp) = new LhsLiteral(t) + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { @@ -285,4 +325,62 @@ package object dsl { def writeToFile(path: String) = WriteToFile(path, logicalPlan) } } + + case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { + def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } + + // scalastyle:off + /** functionToUdfBuilder 1-22 were generated by this script + + (1 to 22).map { x => + val argTypes = Seq.fill(x)("_").mkString(", ") + s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)" + } + */ + + implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + 
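// Editorial aside, not part of the patch: a minimal, hypothetical sketch of how the new
// left-hand-side literal conversions and ScalaUdfBuilder defined above are meant to be used.
// Only the imports and the dsl names are taken from this file; the object itself is illustrative.
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.dsl.expressions._

object LhsLiteralAndUdfBuilderSketch {
  // Without LhsLiteral, 0 < 'x does not compile because Int has no '<' accepting a Symbol;
  // with it, the expression becomes LessThan(Literal(0), UnresolvedAttribute("x")).
  val predicate = 0 < 'x

  // functionToUdfBuilder wraps the closure in a ScalaUdfBuilder, whose call(...) builds a
  // ScalaUdf whose return type is derived via ScalaReflection.schemaFor.
  val addOne = ((i: Int) => i + 1).call('x)
}
// The same ScalaReflection-based bridging appears below in ScalaUdf.eval, where each child value
// is passed through ScalaReflection.convertToScala before the user function is applied.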
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + // scalastyle:on } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fa1786e74bb3e..18c96da2f87fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -34,320 +34,366 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def toString = s"scalaUDF(${children.mkString(",")})" + // scalastyle:off + /** This method has been generated by this script (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) s""" case $x => function.asInstanceOf[($anys) => Any]( - $evals) + $evals) """ - } + }.foreach(println) */ - // scalastyle:off override def eval(input: Row): Any = { val result = children.size match { case 0 => function.asInstanceOf[() => Any]() - case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) + case 1 => + function.asInstanceOf[(Any) => Any]( + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) + + case 2 => function.asInstanceOf[(Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) + + case 3 => function.asInstanceOf[(Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + 
ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) + + case 4 => function.asInstanceOf[(Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) + + case 5 => function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) + + case 6 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) + + case 7 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) + + case 8 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + 
ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) + + case 9 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) + + case 10 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) + + case 11 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + 
ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) + + case 12 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) + + case 13 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) + + case 14 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - 
children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) + + case 15 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) + + case 16 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - 
children(14).eval(input), - children(15).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) + + case 17 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) + + case 18 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - 
children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) + + case 19 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + 
ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) + + case 20 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) + + case 21 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - 
children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) + + case 22 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input), - children(21).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), 
children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), + ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) + } // scalastyle:on diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 570379c533e1f..9f977bf6c2a0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types._ +import java.sql.{Date, Timestamp} /** * A parent class for mutable container objects that are reused when the values are changed, @@ -169,6 +170,35 @@ final class MutableByte extends MutableValue { newCopy.asInstanceOf[this.type] } } +final class MutableDate extends MutableValue { + var value: Date = new Date(0) + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Date] + } + def copy() = { + val newCopy = new MutableDate + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableTimestamp extends MutableValue { + var value: Timestamp = new Timestamp(0) + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Timestamp] + } + def copy() = { + val newCopy = new MutableTimestamp + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} final class MutableAny extends MutableValue { var value: Any = _ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cc5015ad3c013..5dd19dd12d8dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -71,6 +71,8 @@ object DataType { case JSortedObject( ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), ("type", JString("udt"))) => Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } @@ -213,7 +215,7 @@ trait PrimitiveType extends DataType { } object PrimitiveType { - private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all + private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap /** Given the string representation of a type, return its DataType */ @@ -593,6 +595,9 @@ abstract class UserDefinedType[UserType] extends DataType with 
Serializable { /** Underlying storage type for this UDT */ def sqlType: DataType + /** Paired Python UDT class, if exists. */ + def pyUDT: String = null + /** * Convert the user type to a SQL datum * @@ -606,7 +611,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def jsonValue: JValue = { ("type" -> "udt") ~ - ("class" -> this.getClass.getName) + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala new file mode 100644 index 0000000000000..fcb77e640a8ea --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} +import scala.language.implicitConversions + +/* + * Subclass of java.sql.Date which provides the usual comparison + * operators (as required for catalyst expressions) and which can + * be constructed from a string. + * + * scala> val d1 = Date("2014-02-01") + * d1: Date = 2014-02-01 + * + * scala> val d2 = Date("2014-02-02") + * d2: Date = 2014-02-02 + * + * scala> d1 < d2 + * res1: Boolean = true + */ + +class RichDate(milliseconds: Long) extends Date(milliseconds) { + def <(that: Date): Boolean = this.before(that) + def >(that: Date): Boolean = this.after(that) + def <=(that: Date): Boolean = (this.before(that) || this.equals(that)) + def >=(that: Date): Boolean = (this.after(that) || this.equals(that)) + def ===(that: Date): Boolean = this.equals(that) +} + +object RichDate { + def apply(init: String) = new RichDate(Date.valueOf(init).getTime) +} + +/* + * Analogous subclass of java.sql.Timestamp. + * + * scala> val ts1 = Timestamp("2014-03-04 12:34:56.12") + * ts1: Timestamp = 2014-03-04 12:34:56.12 + * + * scala> val ts2 = Timestamp("2014-03-04 12:34:56.13") + * ts2: Timestamp = 2014-03-04 12:34:56.13 + * + * scala> ts1 < ts2 + * res13: Boolean = true + */ + +class RichTimestamp(milliseconds: Long) extends Timestamp(milliseconds) { + def <(that: Timestamp): Boolean = this.before(that) + def >(that: Timestamp): Boolean = this.after(that) + def <=(that: Timestamp): Boolean = (this.before(that) || this.equals(that)) + def >=(that: Timestamp): Boolean = (this.after(that) || this.equals(that)) + def ===(that: Timestamp): Boolean = this.equals(that) +} + +object RichTimestamp { + def apply(init: String) = new RichTimestamp(Timestamp.valueOf(init).getTime) +} + +/* + * Implicit conversions. 
+ */
+
+object TimeConversions {
+
+  implicit def javaDateToRichDate(jdate: Date): RichDate = {
+    new RichDate(jdate.getTime)
+  }
+
+  implicit def javaTimestampToRichTimestamp(jtimestamp: Timestamp): RichTimestamp = {
+    new RichTimestamp(jtimestamp.getTime)
+  }
+
+  implicit def richDateToJavaDate(date: RichDate): Date = {
+    new Date(date.getTime)
+  }
+
+  implicit def richTimestampToJavaTimestamp(timestamp: RichTimestamp): Timestamp = {
+    new Timestamp(timestamp.getTime)
+  }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 6bfa0dbd65ba7..472e97cad79b7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -32,6 +32,13 @@ import org.apache.spark.sql.catalyst.types._
 /* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
 
+/*
+ * Note: the DSL conversions collide with the scalatest === operator!
+ * We can apply the scalatest conversion explicitly:
+ *   assert(X === Y) --> assert(EQ(X).===(Y))
+ */
+import org.scalatest.Assertions.{convertToEqualizer => EQ}
+
 class ExpressionEvaluationSuite extends FunSuite {
 
   test("literals") {
@@ -318,18 +325,18 @@ class ExpressionEvaluationSuite extends FunSuite {
 
     intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
 
-    assert(("abcdef" cast StringType).nullable === false)
-    assert(("abcdef" cast BinaryType).nullable === false)
-    assert(("abcdef" cast BooleanType).nullable === false)
-    assert(("abcdef" cast TimestampType).nullable === true)
-    assert(("abcdef" cast LongType).nullable === true)
-    assert(("abcdef" cast IntegerType).nullable === true)
-    assert(("abcdef" cast ShortType).nullable === true)
-    assert(("abcdef" cast ByteType).nullable === true)
-    assert(("abcdef" cast DecimalType.Unlimited).nullable === true)
-    assert(("abcdef" cast DecimalType(4, 2)).nullable === true)
-    assert(("abcdef" cast DoubleType).nullable === true)
-    assert(("abcdef" cast FloatType).nullable === true)
+    assert(EQ(("abcdef" cast StringType).nullable).===(false))
+    assert(EQ(("abcdef" cast BinaryType).nullable).===(false))
+    assert(EQ(("abcdef" cast BooleanType).nullable).===(false))
+    assert(EQ(("abcdef" cast TimestampType).nullable).===(true))
+    assert(EQ(("abcdef" cast LongType).nullable).===(true))
+    assert(EQ(("abcdef" cast IntegerType).nullable).===(true))
+    assert(EQ(("abcdef" cast ShortType).nullable).===(true))
+    assert(EQ(("abcdef" cast ByteType).nullable).===(true))
+    assert(EQ(("abcdef" cast DecimalType.Unlimited).nullable).===(true))
+    assert(EQ(("abcdef" cast DecimalType(4, 2)).nullable).===(true))
+    assert(EQ(("abcdef" cast DoubleType).nullable).===(true))
+    assert(EQ(("abcdef" cast FloatType).nullable).===(true))
 
     checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
   }
@@ -346,15 +353,15 @@ class ExpressionEvaluationSuite extends FunSuite {
     // - Values that would overflow the target precision should turn into null
     // - Because of this, casts to fixed-precision decimals should be nullable
 
-    assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false)
-    assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false)
-    assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false)
-    assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable ===
false) + assert(EQ(Cast(Literal(123), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(10.03f), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(10.03), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable).===(false)) - assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + assert(EQ(Cast(Literal(123), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(10.03f), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(10.03), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable).===(true)) checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) @@ -500,26 +507,26 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4, c6)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5)).nullable).===(true)) val c4_notNull = 'a.boolean.notNull.at(3) val c5_notNull = 'a.boolean.notNull.at(4) val c6_notNull = 'a.boolean.notNull.at(5) - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable).===(false)) + assert(EQ(CaseWhen(Seq(c2, c4, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c6)).nullable).===(true)) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable).===(false)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable).===(true)) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable).===(true)) } test("complex type") { @@ -559,11 +566,11 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable 
=== false) + assert(EQ(GetField(BoundReference(2,typeS, nullable = true), "a").nullable).===(true)) + assert(EQ(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable).===(false)) - assert(GetField(Literal(null, typeS), "a").nullable === true) - assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(EQ(GetField(Literal(null, typeS), "a").nullable).===(true)) + assert(EQ(GetField(Literal(null, typeS_notNullable), "a").nullable).===(true)) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) @@ -717,10 +724,10 @@ class ExpressionEvaluationSuite extends FunSuite { val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(EQ(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable).===(true)) + assert(EQ(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable).===(false)) + assert(EQ(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable).===(true)) + assert(EQ(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable).===(true)) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) @@ -772,4 +779,25 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + test("recognizes literals on the left") { + assert(EQ(-1 + 'x).===(Add(-1, 'x))) + assert(EQ(0 < 'x).===(LessThan(0, 'x))) + assert(EQ(1.5 === 'x).===(EqualTo(1.5, 'x))) + assert(EQ(false !== 'x).===(Not(EqualTo(false, 'x)))) + assert(EQ("a string" >= 'x).===(GreaterThanOrEqual("a string", 'x))) + assert(EQ(RichDate("2014-11-05") > 'date).===(GreaterThan(RichDate("2014-11-05"), 'date))) + assert(EQ(RichTimestamp("2014-11-05 12:34:56.789") < 'now).===( + LessThan(RichTimestamp("2014-11-05 12:34:56.789"), 'now))) + } + + test("comparison operators for RichDate and RichTimestamp") { + assert(EQ(RichDate("2014-11-05") < RichDate("2014-11-06")).===(true)) + assert(EQ(RichDate("2014-11-05") <= RichDate("2013-11-06")).===(false)) + assert(EQ(RichTimestamp("2014-11-05 12:34:56.5432") > RichTimestamp("2014-11-05 00:00:00") + ).===(true)) + assert(EQ(RichTimestamp("2014-11-05 12:34:56") >= RichTimestamp("2014-11-06 00:00:00") + ).===(false)) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 036fd3fa1d6a1..3586e6557aa14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NullType} +/* + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class Dummy(optKey: Option[Expression]) extends Expression { def children = optKey.toSeq def nullable = true @@ -36,21 +43,21 @@ case class Dummy(optKey: Option[Expression]) extends Expression { class TreeNodeSuite extends FunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } - assert(after === Literal(2)) + assert(EQ(after).===(Literal(2))) } test("one child changed") { val before = Add(Literal(1), Literal(2)) val after = before transform { case Literal(2, _) => Literal(1) } - assert(after === Add(Literal(1), Literal(1))) + assert(EQ(after).===(Add(Literal(1), Literal(1)))) } test("no change") { val before = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) val after = before transform { case Literal(5, _) => Literal(1)} - assert(before === after) + assert(EQ(before).===(after)) // Ensure that the objects after are the same objects before the transformation. before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach { case (b, a) => assert(b eq a) @@ -61,7 +68,7 @@ class TreeNodeSuite extends FunSuite { val tree = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) val literals = tree collect {case l: Literal => l} - assert(literals.size === 4) + assert(EQ(literals.size).===(4)) (1 to 4).foreach(i => assert(literals contains Literal(i))) } @@ -74,7 +81,7 @@ class TreeNodeSuite extends FunSuite { case l: Literal => actual.append(l.toString); l } - assert(expected === actual) + assert(EQ(expected).===(actual)) } test("post-order transform") { @@ -86,7 +93,7 @@ class TreeNodeSuite extends FunSuite { case l: Literal => actual.append(l.toString); l } - assert(expected === actual) + assert(EQ(expected).===(actual)) } test("transform works on nodes with Option children") { @@ -95,13 +102,13 @@ class TreeNodeSuite extends FunSuite { val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } var actual = dummy1 transformDown toZero - assert(actual === Dummy(Some(Literal(0)))) + assert(EQ(actual).===(Dummy(Some(Literal(0))))) actual = dummy1 transformUp toZero - assert(actual === Dummy(Some(Literal(0)))) + assert(EQ(actual).===(Dummy(Some(Literal(0))))) actual = dummy2 transform toZero - assert(actual === Dummy(None)) + assert(EQ(actual).===(Dummy(None))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 07e6e2eccddf4..279495aa64755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -79,13 +79,13 @@ private[sql] trait SQLConf { private[spark] def dialect: String = getConf(DIALECT, "sql") /** When true tables cached using the in-memory columnar caching will be compressed. 
*/ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "snappy") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt @@ -106,10 +106,10 @@ private[sql] trait SQLConf { * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. * - * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt + getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9e61d18f7e926..98370dcf81713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation @@ -483,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true + case udt: UserDefinedType[_] => needsConversion(udt.sqlType) case other => false } @@ -500,4 +502,10 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } + + /* * + * Map RichDate and RichTimestamp to their expected names in this context. 
+ */ + val Date = org.apache.spark.sql.catalyst.expressions.RichDate + val Timestamp = org.apache.spark.sql.catalyst.expressions.RichTimestamp } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3ee2ea05cfa2d..fbec2f9f4b2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.{List => JList} +import org.apache.spark.api.python.SerDeUtil + import scala.collection.JavaConversions._ import net.razorvine.pickle.Pickler @@ -385,12 +387,8 @@ class SchemaRDD( */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val fieldTypes = schema.fields.map(_.dataType) - this.mapPartitions { iter => - val pickle = new Pickler - iter.map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)) - } + val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 1e0ccb368a276..78e8d908fe0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -47,6 +47,9 @@ class JavaSchemaRDD( private[sql] val baseSchemaRDD = new SchemaRDD(sqlContext, logicalPlan) + /** Returns the underlying Scala SchemaRDD. */ + val schemaRDD: SchemaRDD = baseSchemaRDD + override val classTag = scala.reflect.classTag[Row] override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 997669051ed07..a83cf5d441d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -135,6 +135,8 @@ object EvaluatePython { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type }.asJava + case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal // Pyrolite can handle Timestamp @@ -177,6 +179,9 @@ object EvaluatePython { case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) + case (_, udt: UserDefinedType[_]) => + fromJava(obj, udt.sqlType) + case (c: Int, ByteType) => c.toByte case (c: Long, ByteType) => c.toByte case (c: Int, ShortType) => c.toShort diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 51dad54f1a3f3..a1b3a221978f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -460,4 +460,37 @@ package object sql { */ @DeveloperApi type MetadataBuilder = catalyst.util.MetadataBuilder + + /** + * :: DeveloperApi :: + * + * A Date class which support the standard comparison operators, for + * use in DSL expressions. Implicit conversions to java.sql.Date + * are provided. The class intializer accepts a String, e.g. 
+ * + * {{{ + * val d = RichDate("2014-01-01") + * }}} + * + * @group dataType + */ + @DeveloperApi + val RichDate = catalyst.expressions.RichDate + + /** + * :: DeveloperApi :: + * + * A Timestamp class which support the standard comparison + * operators, for use in DSL expressions. Implicit conversions to + * java.sql.timestamp are provided. The class intializer accepts a + * String, e.g. + * + * {{{ + * val ts = RichTimestamp("2014-01-01 12:34:56.78") + * }}} + * + * @group timeClasses + */ + @DeveloperApi + val RichTimestamp = catalyst.expressions.RichTimestamp } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 9664c565a0b86..d00860a8bb8a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -123,7 +123,7 @@ case class ParquetTableScan( // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.set( SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( @@ -394,7 +394,7 @@ private[parquet] class FilteringParquetRowInputFormat if (footers eq null) { val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap if (statuses.isEmpty) { @@ -493,7 +493,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter; import parquet.filter2.compat.RowGroupFilter; - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val filter: Filter = ParquetInputFormat.getFilter(configuration) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala new file mode 100644 index 0000000000000..b9569e96c0312 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.test + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types._ + +/** + * An example class to demonstrate UDT in Scala, Java, and Python. + * @param x x coordinate + * @param y y coordinate + */ +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +private[sql] class ExamplePoint(val x: Double, val y: Double) + +/** + * User-defined type for [[ExamplePoint]]. + */ +private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case p: ExamplePoint => + Seq(p.x, p.y) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: Seq[_] => + val xy = values.asInstanceOf[Seq[Double]] + assert(xy.length == 2) + new ExamplePoint(xy(0), xy(1)) + case values: util.ArrayList[_] => + val xy = values.asInstanceOf[util.ArrayList[Double]].asScala + new ExamplePoint(xy(0), xy(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 1bc15146f0fe8..3fa4a7c6481d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.UserDefinedType - protected[sql] object DataTypeConversions { /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 765fa82776341..bc8efd787b1dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,13 @@ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} +/* + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class BigData(s: String) class CachedTableSuite extends QueryTest { @@ -74,7 +81,7 @@ class CachedTableSuite extends QueryTest { val data = "*" * 10000 sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) + assert(EQ(table("bigData").count()).===(200000L)) table("bigData").unpersist(blocking = true) } @@ -228,7 +235,7 @@ class CachedTableSuite extends QueryTest { table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum - assert(cached.statistics.sizeInBytes === actualSizeInBytes) + assert(EQ(cached.statistics.sizeInBytes).===(actualSizeInBytes)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 45e58afe9d9a2..d09f4296690e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,14 +19,29 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.test.TestSQLContext._ + +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + + +/* + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class DslQuerySuite extends QueryTest { - import TestData._ + import org.apache.spark.sql.TestData._ test("table scan") { checkAnswer( @@ -172,7 +187,7 @@ class DslQuerySuite extends QueryTest { } test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) + assert(EQ(testData2.count()).===(testData2.map(_ => 1).count())) } test("null count") { @@ -193,7 +208,7 @@ class DslQuerySuite extends QueryTest { } test("zero count") { - assert(emptyTableData.count() === 0) + assert(EQ(emptyTableData.count()).===(0)) } test("except") { @@ -216,4 +231,14 @@ class DslQuerySuite extends QueryTest { (4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } + + test("udf") { + val foo = (a: Int, b: String) => a.toString + b + + checkAnswer( + // SELECT *, foo(key, value) FROM testData + testData.select(Star(None), foo.call('key, 'value)).limit(3), + (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8b4cf5bac0187..3d1f1801778c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,6 +25,13 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOu import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData @@ -34,7 +41,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.as('y) val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed val planned = planner.HashJoin(join) - assert(planned.size === 1) + assert(EQ(planned.size).===(1)) } def assertJoin(sqlString: String, c: Class[_]): Any = { @@ -50,7 +57,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j } - assert(operators.size === 1) + assert(EQ(operators.size).===(1)) if (operators(0).getClass() != c) { fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") } @@ -104,7 +111,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed val planned = planner.HashJoin(join) - assert(planned.size === 1) + assert(EQ(planned.size).===(1)) } test("inner join where, one match per row") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 60701f0e154f8..df22ec9b1bcd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -24,6 +24,13 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +/* + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class SQLConfSuite extends QueryTest with FunSuiteLike { val testKey = "test.key.0" @@ -38,7 +45,7 @@ class SQLConfSuite extends QueryTest with FunSuiteLike { test("programmatic ways of basic setting and getting") { clear() - assert(getAllConfs.size === 0) + assert(EQ(getAllConfs.size).===(0)) setConf(testKey, testVal) assert(getConf(testKey) == testVal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6bf439377aa3e..702714af5308d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -938,4 +938,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), (11 to 100).map(i => Seq(i))) } + + test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { + checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), + (1 to 99).map(i => Seq(i))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ef9b76b1e251e..5d5c2b0b5168e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,18 +22,25 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { test("Simple UDF") { registerFunction("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(EQ(sql("SELECT strLenScala('test')").first().getInt(0)).===(4)) } test("TwoArgument UDF") { registerFunction("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(EQ(sql("SELECT strLenScala('test', 1)").first().getInt(0)).===(5)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 666235e57f812..3c1505fd63e43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,13 @@ import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -60,24 +67,32 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) - test("register user type: MyDenseVector for MyLabeledPoint") { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() - assert(labelsArrays.size === 2) + assert(EQ(labelsArrays.size).===(2)) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[MyDenseVector] = pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } val featuresArrays: Array[MyDenseVector] = features.collect() - assert(featuresArrays.size === 2) + assert(EQ(featuresArrays.size).===(2)) assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } + + test("UDTs and UDFs") { + registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 9ba3c210171bd..a6f31e0e15302 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -22,6 +22,13 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning @@ -107,8 +114,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head - assert(readBatches === expectedReadBatches, "Wrong number of read batches") - assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + assert(EQ(readBatches).===(expectedReadBatches), "Wrong number of read batches") + assert(EQ(readPartitions).===(expectedReadPartitions), "Wrong number of read partitions") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a5af71acfc79a..2c43c110e32b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -27,6 +27,13 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -34,8 +41,8 @@ class PlannerSuite extends FunSuite { val logicalUnions = query collect { case u: logical.Union => u } val physicalUnions = planned collect { case u: execution.Union => u } - assert(logicalUnions.size === 2) - assert(physicalUnions.size === 1) + assert(EQ(logicalUnions.size).===(2)) + assert(EQ(physicalUnions.size).===(1)) } test("count is partially aggregated") { @@ -43,7 +50,7 @@ class PlannerSuite extends FunSuite { val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - assert(aggregations.size === 2) + assert(EQ(aggregations.size).===(2)) } test("count distinct is partially aggregated") { @@ -71,7 +78,7 @@ class PlannerSuite extends FunSuite { val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(EQ(broadcastHashJoins.size).===(1), "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) @@ -91,7 +98,7 @@ class PlannerSuite extends FunSuite { val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(EQ(broadcastHashJoins.size).===(1), "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") 
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08d9da27f1b11..0f53e9736fb20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -30,6 +30,13 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class TestRDDEntry(key: Int, value: String) case class NullReflectData( @@ -172,7 +179,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -188,7 +195,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -204,7 +211,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === "UNCOMPRESSED" :: Nil) + assert(EQ(actualCodec).===("UNCOMPRESSED" :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -220,7 +227,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -236,7 +243,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: 
Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -285,8 +292,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } val result = query.collect() - assert(result.size === 9, "self-join result has incorrect size") - assert(result(0).size === 12, "result row has incorrect size") + assert(EQ(result.size).===(9), "self-join result has incorrect size") + assert(EQ(result(0).size).===(12), "result row has incorrect size") result.zipWithIndex.foreach { case (row, index) => row.zipWithIndex.foreach { case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") @@ -296,7 +303,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Import of simple Parquet file") { val result = parquetFile(ParquetTestData.testDir.toString).collect() - assert(result.size === 15) + assert(EQ(result.size).===(15)) result.zipWithIndex.foreach { case (row, index) => { val checkBoolean = @@ -304,12 +311,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA row(0) == true else row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") + assert(EQ(checkBoolean).===(true), s"boolean field value in line $index did not match") + if (index % 5 == 0) assert(EQ(row(1)).===(5), s"int field value in line $index did not match") + assert(EQ(row(2)).===("abc"), s"string field value in line $index did not match") + assert(EQ(row(3)).===((index.toLong << 33)), s"long value in line $index did not match") + assert(EQ(row(4)).===(2.5F), s"float field value in line $index did not match") + assert(EQ(row(5)).===(4.5D), s"double field value in line $index did not match") } } } @@ -319,11 +326,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA result.zipWithIndex.foreach { case (row, index) => { if (index % 3 == 0) - assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") + assert(EQ(row(0)).===(true), s"boolean field value in line $index did not match (every third row)") else - assert(row(0) === false, s"boolean field value in line $index did not match") - assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") - assert(row.size === 2, s"number of columns in projection in line $index is incorrect") + assert(EQ(row(0)).===(false), s"boolean field value in line $index did not match") + assert(EQ(row(1)).===((index.toLong << 33)), s"long field value in line $index did not match") + assert(EQ(row.size).===(2), s"number of columns in projection in line $index is incorrect") } } } @@ -381,8 +388,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") + 
assert(EQ(rdd_copy(i).apply(0)).===(rdd_orig(i).key), s"key error in line $i") + assert(EQ(rdd_copy(i).apply(1)).===(rdd_orig(i).value), s"value error in line $i") } Utils.deleteRecursively(file) } @@ -396,11 +403,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA dest_rdd.registerTempTable("dest") sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() - assert(rdd_copy1.size === 100) + assert(EQ(rdd_copy1.size).===(100)) sql("INSERT INTO dest SELECT * FROM source") val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) - assert(rdd_copy2.size === 200) + assert(EQ(rdd_copy2.size).===(200)) Utils.deleteRecursively(dirname) } @@ -408,7 +415,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) - assert(double_rdd.size === 30) + assert(EQ(double_rdd.size).===(30)) // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) @@ -425,7 +432,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(EQ(rdd_saved(0)).===(Seq.fill(5)(null))) Utils.deleteRecursively(file) assert(true) } @@ -440,7 +447,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(EQ(rdd_saved(0)).===(Seq.fill(5)(null))) Utils.deleteRecursively(file) assert(true) } @@ -478,11 +485,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val attribute2 = new AttributeReference("second", IntegerType, false)() val predicate5 = new GreaterThan(attribute1, attribute2) val badfilter = ParquetFilters.createFilter(predicate5) - assert(badfilter.isDefined === false) + assert(EQ(badfilter.isDefined).===(false)) val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) val badfilter2 = ParquetFilters.createFilter(predicate6) - assert(badfilter2.isDefined === false) + assert(EQ(badfilter2.isDefined).===(false)) } test("test filter by predicate pushdown") { @@ -492,21 +499,21 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result1 = query1.collect() - assert(result1.size === 50) - assert(result1(0)(1) === 100) - assert(result1(49)(1) === 149) + assert(EQ(result1.size).===(50)) + assert(EQ(result1(0)(1)).===(100)) + assert(EQ(result1(49)(1)).===(149)) val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") assert( query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result2 = query2.collect() - assert(result2.size === 50) + assert(EQ(result2.size).===(50)) if (myval == "myint" || myval == "mylong") { - assert(result2(0)(1) === 151) - assert(result2(49)(1) === 200) + assert(EQ(result2(0)(1)).===(151)) + assert(EQ(result2(49)(1)).===(200)) } else { - assert(result2(0)(1) === 150) - assert(result2(49)(1) === 199) + assert(EQ(result2(0)(1)).===(150)) + assert(EQ(result2(49)(1)).===(199)) } 
} for(myval <- Seq("myint", "mylong")) { @@ -515,11 +522,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result3 = query3.collect() - assert(result3.size === 20) - assert(result3(0)(1) === 0) - assert(result3(9)(1) === 9) - assert(result3(10)(1) === 191) - assert(result3(19)(1) === 200) + assert(EQ(result3.size).===(20)) + assert(EQ(result3(0)(1)).===(0)) + assert(EQ(result3(9)(1)).===(9)) + assert(EQ(result3(10)(1)).===(191)) + assert(EQ(result3(19)(1)).===(200)) } for(myval <- Seq("mydouble", "myfloat")) { val result4 = @@ -534,18 +541,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA // currently no way to specify float constants in SqlParser? sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() } - assert(result4.size === 20) - assert(result4(0)(1) === 0) - assert(result4(9)(1) === 9) - assert(result4(10)(1) === 191) - assert(result4(19)(1) === 200) + assert(EQ(result4.size).===(20)) + assert(EQ(result4(0)(1)).===(0)) + assert(EQ(result4(9)(1)).===(9)) + assert(EQ(result4(10)(1)).===(191)) + assert(EQ(result4(19)(1)).===(200)) } val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") assert( query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val booleanResult = query5.collect() - assert(booleanResult.size === 10) + assert(EQ(booleanResult.size).===(10)) for(i <- 0 until 10) { if (!booleanResult(i).getBoolean(0)) { fail(s"Boolean value in result row $i not true") @@ -559,16 +566,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val stringResult = query6.collect() - assert(stringResult.size === 1) + assert(EQ(stringResult.size).===(1)) assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") - assert(stringResult(0).getInt(1) === 100) + assert(EQ(stringResult(0).getInt(1)).===(100)) val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") assert( query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val optResult = query7.collect() - assert(optResult.size === 20) + assert(EQ(optResult.size).===(20)) for(i <- 0 until 20) { if (optResult(i)(7) != i * 2) { fail(s"optional Int value in result row $i should be ${2*4*i}") @@ -580,21 +587,21 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result8 = query8.collect() - assert(result8.size === 25) - assert(result8(0)(7) === 100) - assert(result8(24)(7) === 148) + assert(EQ(result8.size).===(25)) + assert(EQ(result8(0)(7)).===(100)) + assert(EQ(result8(24)(7)).===(148)) val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") assert( query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result9 = query9.collect() - assert(result9.size === 25) + assert(EQ(result9.size).===(25)) if (myval == "myoptint" || myval == "myoptlong") { - assert(result9(0)(7) 
=== 152) - assert(result9(24)(7) === 200) + assert(EQ(result9(0)(7)).===(152)) + assert(EQ(result9(24)(7)).===(200)) } else { - assert(result9(0)(7) === 150) - assert(result9(24)(7) === 198) + assert(EQ(result9(0)(7)).===(150)) + assert(EQ(result9(24)(7)).===(198)) } } val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") @@ -602,15 +609,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result10 = query10.collect() - assert(result10.size === 1) + assert(EQ(result10.size).===(1)) assert(result10(0).getString(8) == "100", "stringvalue incorrect") - assert(result10(0).getInt(7) === 100) + assert(EQ(result10(0).getInt(7)).===(100)) val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") assert( query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result11 = query11.collect() - assert(result11.size === 7) + assert(EQ(result11.size).===(7)) for(i <- 0 until 6) { if (!result11(i).getBoolean(6)) { fail(s"optional Boolean value in result row $i not true") @@ -623,7 +630,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") - assert(query.collect().size === 10) + assert(EQ(query.collect().size).===(10)) } test("Importing nested Parquet file (Addressbook)") { @@ -632,32 +639,32 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .collect() assert(result != null) - assert(result.size === 2) + assert(EQ(result.size).===(2)) val first_record = result(0) val second_record = result(1) assert(first_record != null) assert(second_record != null) - assert(first_record.size === 3) - assert(second_record(1) === null) - assert(second_record(2) === null) - assert(second_record(0) === "A. Nonymous") - assert(first_record(0) === "Julien Le Dem") + assert(EQ(first_record.size).===(3)) + assert(EQ(second_record(1)).===(null)) + assert(EQ(second_record(2)).===(null)) + assert(EQ(second_record(0)).===("A. 
Nonymous")) + assert(EQ(first_record(0)).===("Julien Le Dem")) val first_owner_numbers = first_record(1) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] val first_contacts = first_record(2) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] assert(first_owner_numbers != null) - assert(first_owner_numbers(0) === "555 123 4567") - assert(first_owner_numbers(2) === "XXX XXX XXXX") - assert(first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) + assert(EQ(first_owner_numbers(0)).===("555 123 4567")) + assert(EQ(first_owner_numbers(2)).===("XXX XXX XXXX")) + assert(EQ(first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]].size).===(2)) val first_contacts_entry_one = first_contacts(0) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") - assert(first_contacts_entry_one(1) === "555 987 6543") + assert(EQ(first_contacts_entry_one(0)).===("Dmitriy Ryaboy")) + assert(EQ(first_contacts_entry_one(1)).===("555 987 6543")) val first_contacts_entry_two = first_contacts(1) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_two(0) === "Chris Aniszczyk") + assert(EQ(first_contacts_entry_two(0)).===("Chris Aniszczyk")) } test("Importing nested Parquet file (nested numbers)") { @@ -665,31 +672,31 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .parquetFile(ParquetTestData.testNestedDir2.toString) .toSchemaRDD .collect() - assert(result.size === 1, "number of top-level rows incorrect") - assert(result(0).size === 5, "number of fields in row incorrect") - assert(result(0)(0) === 1) - assert(result(0)(1) === 7) + assert(EQ(result.size).===(1), "number of top-level rows incorrect") + assert(EQ(result(0).size).===(5), "number of fields in row incorrect") + assert(EQ(result(0)(0)).===(1)) + assert(EQ(result(0)(1)).===(7)) val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult1.size === 3) - assert(subresult1(0) === (1.toLong << 32)) - assert(subresult1(1) === (1.toLong << 33)) - assert(subresult1(2) === (1.toLong << 34)) + assert(EQ(subresult1.size).===(3)) + assert(EQ(subresult1(0)).===((1.toLong << 32))) + assert(EQ(subresult1(1)).===((1.toLong << 33))) + assert(EQ(subresult1(2)).===((1.toLong << 34))) val subresult2 = result(0)(3) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult2.size === 2) - assert(subresult2(0) === 2.5) - assert(subresult2(1) === false) + assert(EQ(subresult2.size).===(2)) + assert(EQ(subresult2(0)).===(2.5)) + assert(EQ(subresult2(1)).===(false)) val subresult3 = result(0)(4) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult3.size === 2) - assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) + assert(EQ(subresult3.size).===(2)) + assert(EQ(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size).===(2)) val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + 
assert(EQ(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(7)) + assert(EQ(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(8)) + assert(EQ(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size).===(1)) + assert(EQ(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(9)) } test("Simple query on addressbook") { @@ -697,8 +704,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .parquetFile(ParquetTestData.testNestedDir1.toString) .toSchemaRDD val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() - assert(tmp.size === 1) - assert(tmp(0)(0) === "Julien Le Dem") + assert(EQ(tmp.size).===(1)) + assert(EQ(tmp(0)(0)).===("Julien Le Dem")) } test("Projection in addressbook") { @@ -706,37 +713,37 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA data.registerTempTable("data") val query = sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() - assert(tmp.size === 2) - assert(tmp(0).size === 2) - assert(tmp(0)(0) === "Julien Le Dem") - assert(tmp(0)(1) === "Chris Aniszczyk") - assert(tmp(1)(0) === "A. Nonymous") - assert(tmp(1)(1) === null) + assert(EQ(tmp.size).===(2)) + assert(EQ(tmp(0).size).===(2)) + assert(EQ(tmp(0)(0)).===("Julien Le Dem")) + assert(EQ(tmp(0)(1)).===("Chris Aniszczyk")) + assert(EQ(tmp(1)(0)).===("A. Nonymous")) + assert(EQ(tmp(1)(1)).===(null)) } test("Simple query on nested int data") { val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD data.registerTempTable("data") val result1 = sql("SELECT entries[0].value FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === 2.5) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0).size).===(1)) + assert(EQ(result1(0)(0)).===(2.5)) val result2 = sql("SELECT entries[0] FROM data").collect() - assert(result2.size === 1) + assert(EQ(result2.size).===(1)) val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult1.size === 2) - assert(subresult1(0) === 2.5) - assert(subresult1(1) === false) + assert(EQ(subresult1.size).===(2)) + assert(EQ(subresult1(0)).===(2.5)) + assert(EQ(subresult1(1)).===(false)) val result3 = sql("SELECT outerouter FROM data").collect() val subresult2 = result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(result3(0)(0) + assert(EQ(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(7)) + assert(EQ(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(8)) + assert(EQ(result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(9)) } test("nested structs") { @@ -744,17 +751,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD data.registerTempTable("data") val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === false) + 
assert(EQ(result1.size).===(1)) + assert(EQ(result1(0).size).===(1)) + assert(EQ(result1(0)(0)).===(false)) val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() - assert(result2.size === 1) - assert(result2(0).size === 1) - assert(result2(0)(0) === true) + assert(EQ(result2.size).===(1)) + assert(EQ(result2(0).size).===(1)) + assert(EQ(result2(0)(0)).===(true)) val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() - assert(result3.size === 1) - assert(result3(0).size === 1) - assert(result3(0)(0) === false) + assert(EQ(result3.size).===(1)) + assert(EQ(result3(0).size).===(1)) + assert(EQ(result3(0)(0)).===(false)) } test("simple map") { @@ -763,38 +770,38 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD data.registerTempTable("mapTable") val result1 = sql("SELECT data1 FROM mapTable").collect() - assert(result1.size === 1) - assert(result1(0)(0) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key1", 0) === 1) - assert(result1(0)(0) + .getOrElse("key1", 0)).===(1)) + assert(EQ(result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key2", 0) === 2) + .getOrElse("key2", 0)).===(2)) val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() - assert(result2(0)(0) === 1) + assert(EQ(result2(0)(0)).===(1)) } test("map with struct values") { val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD data.registerTempTable("mapTable") val result1 = sql("SELECT data2 FROM mapTable").collect() - assert(result1.size === 1) + assert(EQ(result1.size).===(1)) val entry1 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("seven", null) assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") + assert(EQ(entry1(0)).===(42)) + assert(EQ(entry1(1)).===("the answer")) val entry2 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("eight", null) assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) + assert(EQ(entry2(0)).===(49)) + assert(EQ(entry2(1)).===(null)) val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() - assert(result2.size === 1) - assert(result2(0)(0) === 42.toLong) - assert(result2(0)(1) === "the answer") + assert(EQ(result2.size).===(1)) + assert(EQ(result2(0)(0)).===(42.toLong)) + assert(EQ(result2(0)(1)).===("the answer")) } test("Writing out Addressbook and reading it back in") { @@ -808,12 +815,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .registerTempTable("tmpcopy") val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() - assert(tmpdata.size === 2) - assert(tmpdata(0).size === 2) - assert(tmpdata(0)(0) === "Julien Le Dem") - assert(tmpdata(0)(1) === "Chris Aniszczyk") - assert(tmpdata(1)(0) === "A. Nonymous") - assert(tmpdata(1)(1) === null) + assert(EQ(tmpdata.size).===(2)) + assert(EQ(tmpdata(0).size).===(2)) + assert(EQ(tmpdata(0)(0)).===("Julien Le Dem")) + assert(EQ(tmpdata(0)(1)).===("Chris Aniszczyk")) + assert(EQ(tmpdata(1)(0)).===("A. 
Nonymous")) + assert(EQ(tmpdata(1)(1)).===(null)) Utils.deleteRecursively(tmpdir) } @@ -826,26 +833,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .registerTempTable("tmpmapcopy") val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() - assert(result1.size === 1) - assert(result1(0)(0) === 2) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0)(0)).===(2)) val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() - assert(result2.size === 1) + assert(EQ(result2.size).===(1)) val entry1 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("seven", null) assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") + assert(EQ(entry1(0)).===(42)) + assert(EQ(entry1(1)).===("the answer")) val entry2 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("eight", null) assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) + assert(EQ(entry2(0)).===(49)) + assert(EQ(entry2(1)).===(null)) val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() - assert(result3.size === 1) - assert(result3(0)(0) === 42.toLong) - assert(result3(0)(1) === "the answer") + assert(EQ(result3.size).===(1)) + assert(EQ(result3(0)(0)).===(42.toLong)) + assert(EQ(result3(0)(1)).===("the answer")) Utils.deleteRecursively(tmpdir) } @@ -854,7 +861,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA Utils.deleteRecursively(tmpdir) createParquetFile[TestRDDEntry](tmpdir.toString()).registerTempTable("tmpemptytable") val result1 = sql("SELECT * FROM tmpemptytable").collect() - assert(result1.size === 0) + assert(EQ(result1.size).===(0)) Utils.deleteRecursively(tmpdir) } @@ -868,7 +875,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA (fromCaseClassString, fromJson).zipped.foreach { (a, b) => assert(a.name == b.name) - assert(a.dataType === b.dataType) + assert(EQ(a.dataType).===(b.dataType)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f025169ad5063..e88afaaf001c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -90,7 +90,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * SerDe. 
*/ private[spark] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 096b4a07aa2ea..0baf4c9f8c7ab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -57,6 +57,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false + def tableExists(db: Option[String], tableName: String): Boolean = { + val (databaseName, tblName) = processDatabaseAndTableName( + db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) + client.getTable(databaseName, tblName, false) != null + } + def lookupRelation( db: Option[String], tableName: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 2fce414734579..3d24d87bc3d38 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -71,7 +71,17 @@ case class CreateTableAsSelect( // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + if (sc.catalog.tableExists(Some(database), tableName)) { + if (allowExisting) { + // table already exists, will do nothing, to keep consistent with Hive + } else { + throw + new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") + } + } else { + sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + } + Seq.empty[Row] } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a90fc023e67d8..3565b377ef602 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -27,6 +27,13 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class StatisticsSuite extends QueryTest with BeforeAndAfterAll { TestHive.reset() TestHive.cacheTables = false @@ -39,7 +46,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { case o => o } - assert(operators.size === 1) + assert(EQ(operators.size).===(1)) if (operators(0).getClass() != c) { fail( s"""$analyzeCommand expected command: $c, but got ${operators(0)} @@ -81,11 +88,11 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // TODO: How does it works? 
needs to add it back for other hive version. if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) + assert(EQ(queryTotalSize("analyzeTable")).===(defaultSizeInBytes)) } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable") === BigInt(11624)) + assert(EQ(queryTotalSize("analyzeTable")).===(BigInt(11624))) sql("DROP TABLE analyzeTable").collect() @@ -110,11 +117,11 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) + assert(EQ(queryTotalSize("analyzeTable_part")).===(defaultSizeInBytes)) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) + assert(EQ(queryTotalSize("analyzeTable_part")).===(BigInt(17436))) sql("DROP TABLE analyzeTable_part").collect() @@ -131,7 +138,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") + assert(EQ(sizes.size).===(1), s"Size wrong for:\n ${rdd.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -151,14 +158,14 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = rdd.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold + assert(EQ(sizes.size).===(2) && sizes(0) <= autoBroadcastJoinThreshold && sizes(1) <= autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } - assert(bhj.size === 1, + assert(EQ(bhj.size).===(1), s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") checkAnswer(rdd, expectedAnswer) // check correctness of output @@ -172,7 +179,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } - assert(shj.size === 1, + assert(EQ(shj.size).===(1), "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b897dff0159ff..3e9805dfe8067 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,6 +30,13 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} +/* + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class TestData(a: Int, b: String) /** @@ -139,7 +146,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("CREATE TABLE AS runs once") { sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + assert(EQ(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0)).===(1), "Incorrect number of rows in created table") } @@ -161,7 +168,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Seq(1))) + assert(EQ(sql("SELECT 1").collect()).===(Array(Seq(1)))) setConf("spark.sql.dialect", "hiveql") } @@ -365,7 +372,7 @@ class HiveQuerySuite extends HiveComparisonTest { .collect() .toSet - assert(actual === expected) + assert(EQ(actual).===(expected)) } // TODO: adopt this test when Spark SQL has the functionality / framework to report errors. @@ -415,7 +422,7 @@ class HiveQuerySuite extends HiveComparisonTest { .collect() .map(x => Pair(x.getString(0), x.getInt(1))) - assert(results === Array(Pair("foo", 4))) + assert(EQ(results).===(Array(Pair("foo", 4)))) TestHive.reset() } @@ -557,8 +564,8 @@ class HiveQuerySuite extends HiveComparisonTest { sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => - assert(map.size === 1) - assert(map.head === (key, value)) + assert(EQ(map.size).===(1)) + assert(EQ(map.head).===((key, value))) } } @@ -654,7 +661,7 @@ class HiveQuerySuite extends HiveComparisonTest { sql("CREATE TABLE dp_verify(intcol INT)") sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") - assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + assert(EQ(sql("SELECT * FROM dp_verify").collect()).===(Array(Row(value)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 76a0ec01a6075..e9b1943ff8db7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -56,7 +56,7 @@ class SQLQuerySuite extends QueryTest { sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect - // expect the string => integer for field key cause the table ctas4 already existed. + // do nothing cause the table ctas4 already existed. 
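The behaviour the updated comment describes comes from the existence check added to CreateTableAsSelect earlier in this patch: an existing table is silently kept when IF NOT EXISTS is given, and rejected otherwise. A simplified sketch of that control flow, with hypothetical helper parameters standing in for the catalog lookup and the real insert plan:

    // Sketch only; tableExists and runCtasInsert are stand-ins, not Spark APIs.
    def createTableAsSelect(
        database: String,
        table: String,
        allowExisting: Boolean,
        tableExists: (String, String) => Boolean,
        runCtasInsert: () => Unit): Unit = {
      if (tableExists(database, table)) {
        if (!allowExisting) {
          // Plain CREATE TABLE AS on an existing table fails, matching Hive's behaviour.
          throw new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$table")
        }
        // IF NOT EXISTS: do nothing and keep the existing table and its data.
      } else {
        runCtasInsert()
      }
    }
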
sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect @@ -78,9 +78,14 @@ class SQLQuerySuite extends QueryTest { SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) + intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { + sql( + """CREATE TABLE ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), - sql("SELECT CAST(key AS int) k, value FROM src ORDER BY k, value").collect().toSeq) + sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) checkExistence(sql("DESC EXTENDED ctas2"), true, "name:key", "type:string", "name:value", "ctas2", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 6f57fe8958387..0069726426f96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -29,6 +29,13 @@ import org.apache.spark.util.Utils // Implicits import org.apache.spark.sql.hive.test.TestHive._ +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(convertToEqualizer(X).===(Y)) + * (This file already imports convertToEqualizer) + */ + case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { @@ -80,7 +87,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("Simple column projection + filter on Parquet table") { val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() - assert(rdd.size === 5, "Filter returned incorrect number of rows") + assert(convertToEqualizer(rdd.size).===(5), "Filter returned incorrect number of rows") assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") } @@ -102,7 +109,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() val rddCopy = sql("SELECT * FROM ptable").collect() val rddOrig = sql("SELECT * FROM testsource").collect() - assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") + assert(convertToEqualizer(rddCopy.size).===(rddOrig.size), "INSERT OVERWRITE changed size of table??") compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) } @@ -111,7 +118,8 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft (rddOne, rddTwo).zipped.foreach { (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach { case ((value_1, value_2), index) => - assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match") + assert(convertToEqualizer(value_1).===(value_2), + s"table $tableName row $counter field ${fieldNames(index)} don't match") } counter = counter + 1 } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index bb47d373de63d..3e67161363e50 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -17,15 +17,14 @@ package org.apache.spark.streaming.dstream -import scala.collection.mutable.HashMap import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver.{WriteAheadLogBasedStoreResult, BlockManagerBasedStoreResult, Receiver} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo -import org.apache.spark.SparkException /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -40,9 +39,6 @@ import org.apache.spark.SparkException abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** Keeps all received blocks information */ - private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]] - /** This is an unique identifier for the network input stream. */ val id = ssc.getNewReceiverStreamId() @@ -58,24 +54,45 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont def stop() {} - /** Ask ReceiverInputTracker for received data blocks and generates RDDs with them. */ + /** + * Generates RDDs with blocks received by the receiver of this stream. */ override def compute(validTime: Time): Option[RDD[T]] = { - // If this is called for any time before the start time of the context, - // then this returns an empty RDD. This may happen when recovering from a - // master failure - if (validTime >= graph.startTime) { - val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id) - receivedBlockInfo(validTime) = blockInfo - val blockIds = blockInfo.map { _.blockStoreResult.blockId.asInstanceOf[BlockId] } - Some(new BlockRDD[T](ssc.sc, blockIds)) - } else { - Some(new BlockRDD[T](ssc.sc, Array.empty)) - } - } + val blockRDD = { - /** Get information on received blocks. */ - private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) + if (validTime < graph.startTime) { + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // driver failure without any write ahead log to recover pre-failure data. + new BlockRDD[T](ssc.sc, Array.empty) + } else { + // Otherwise, ask the tracker for all the blocks that have been allocated to this stream + // for this batch + val blockInfos = + ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty) + val blockStoreResults = blockInfos.map { _.blockStoreResult } + val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Check whether all the results are of the same type + val resultTypes = blockStoreResults.map { _.getClass }.distinct + if (resultTypes.size > 1) { + logWarning("Multiple result types in block information, WAL information will be ignored.") + } + + // If all the results are of type WriteAheadLogBasedStoreResult, then create + // WriteAheadLogBackedBlockRDD else create simple BlockRDD. 
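The check that follows implements this decision. A standalone sketch of the same idea, using simplified stand-in types rather than the real block store result classes:

    sealed trait StoreResult
    case class WalBasedResult(segment: String) extends StoreResult
    case class BlockManagerBasedResult(blockId: String) extends StoreResult

    // The batch can be served from the write ahead log only if every block reported
    // exactly the same result class, and that class is the WAL-based one.
    def allWalBased(results: Seq[StoreResult]): Boolean = {
      val resultTypes = results.map(_.getClass).distinct
      resultTypes.size == 1 && resultTypes.head == classOf[WalBasedResult]
    }
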
+ if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { + val logSegments = blockStoreResults.map { + _.asInstanceOf[WriteAheadLogBasedStoreResult].segment + }.toArray + // Since storeInBlockManager = false, the storage level does not matter. + new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, + blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER) + } else { + new BlockRDD[T](ssc.sc, blockIds) + } + } + } + Some(blockRDD) } /** @@ -86,10 +103,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ private[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration)) - receivedBlockInfo --= oldReceivedBlocks.keys - logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", ")) + ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 23295bf658712..dd1e96334952f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -48,7 +48,6 @@ class WriteAheadLogBackedBlockRDDPartition( * If it does not find them, it looks up the corresponding file segment. * * @param sc SparkContext - * @param hadoopConfig Hadoop configuration * @param blockIds Ids of the blocks that contains this RDD's data * @param segments Segments in write ahead logs that contain this RDD's data * @param storeInBlockManager Whether to store in the block manager after reading from the segment @@ -58,7 +57,6 @@ class WriteAheadLogBackedBlockRDDPartition( private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, - @transient hadoopConfig: Configuration, @transient blockIds: Array[BlockId], @transient segments: Array[WriteAheadLogFileSegment], storeInBlockManager: Boolean, @@ -71,6 +69,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( s"the same as number of segments (${segments.length}})!") // Hadoop configuration is not serializable, so broadcast it as a serializable. 
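The wrapping referred to in this comment is done with Spark's SerializableWritable helper on the lines that follow. A minimal sketch of the idea, outside of the RDD and purely for illustration:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.SerializableWritable

    // Configuration implements Hadoop's Writable but not java.io.Serializable, so it
    // cannot be captured by a task closure directly; the wrapper serializes it through
    // the Writable write/readFields methods instead.
    val hadoopConf = new Configuration()
    val serializableConf = new SerializableWritable(hadoopConf)
    // On the executor side, serializableConf.value hands back a usable Configuration.
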
+ @transient private val hadoopConfig = sc.hadoopConfiguration private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) override def getPartitions: Array[Partition] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 7d73ada12d107..39b66e1130768 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -112,7 +112,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Wait until all the received blocks in the network input tracker has // been consumed by network input DStreams, and jobs have been generated with them logInfo("Waiting for all received blocks to be consumed for job generation") - while(!hasTimedOut && jobScheduler.receiverTracker.hasMoreReceivedBlockIds) { + while(!hasTimedOut && jobScheduler.receiverTracker.hasUnallocatedBlocks) { Thread.sleep(pollTime) } logInfo("Waited for all received blocks to be consumed for job generation") @@ -217,14 +217,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { - Try(graph.generateJobs(time)) match { + // Set the SparkEnv in this thread, so that job generation code can access the environment + // Example: BlockRDDs are created in this thread, and it needs to access BlockManager + // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. + SparkEnv.set(ssc.env) + Try { + jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch + graph.generateJobs(time) // generate jobs using allocated block + } match { case Success(jobs) => - val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => - val streamId = stream.id - val receivedBlockInfo = stream.getReceivedBlockInfo(time) - (streamId, receivedBlockInfo) - }.toMap - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo)) + val receivedBlockInfos = + jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } + jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } @@ -234,6 +238,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) + jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala new file mode 100644 index 0000000000000..5f5e1909908d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} +import org.apache.spark.util.Utils + +/** Trait representing any event in the ReceivedBlockTracker that updates its state. */ +private[streaming] sealed trait ReceivedBlockTrackerLogEvent + +private[streaming] case class BlockAdditionEvent(receivedBlockInfo: ReceivedBlockInfo) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: AllocatedBlocks) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchCleanupEvent(times: Seq[Time]) + extends ReceivedBlockTrackerLogEvent + + +/** Class representing the blocks of all the streams allocated to a batch */ +private[streaming] +case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { + def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = { + streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty) + } +} + +/** + * Class that keep track of all the received blocks, and allocate them to batches + * when required. All actions taken by this class can be saved to a write ahead log + * (if a checkpoint directory has been provided), so that the state of the tracker + * (received blocks and block-to-batch allocations) can be recovered after driver failure. + * + * Note that when any instance of this class is created with a checkpoint directory, + * it will try reading events from logs in the directory. + */ +private[streaming] class ReceivedBlockTracker( + conf: SparkConf, + hadoopConf: Configuration, + streamIds: Seq[Int], + clock: Clock, + checkpointDirOption: Option[String]) + extends Logging { + + private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] + + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] + private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] + + private val logManagerRollingIntervalSecs = conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) + private val logManagerOption = checkpointDirOption.map { checkpointDir => + new WriteAheadLogManager( + ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir), + hadoopConf, + rollingIntervalSecs = logManagerRollingIntervalSecs, + callerName = "ReceivedBlockHandlerMaster", + clock = clock + ) + } + + private var lastAllocatedBatchTime: Time = null + + // Recover block information from write ahead logs + recoverFromWriteAheadLogs() + + /** Add received block. This event will get written to the write ahead log (if enabled). 
*/ + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + try { + writeToLog(BlockAdditionEvent(receivedBlockInfo)) + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + true + } catch { + case e: Exception => + logError(s"Error adding block $receivedBlockInfo", e) + false + } + } + + /** + * Allocate all unallocated blocks to the given batch. + * This event will get written to the write ahead log (if enabled). + */ + def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { + if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { + val streamIdToBlocks = streamIds.map { streamId => + (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + }.toMap + val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) + writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + timeToAllocatedBlocks(batchTime) = allocatedBlocks + lastAllocatedBatchTime = batchTime + allocatedBlocks + } else { + throw new SparkException(s"Unexpected allocation of blocks, " + + s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + } + } + + /** Get the blocks allocated to the given batch. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = synchronized { + timeToAllocatedBlocks.get(batchTime).map { _.streamIdToAllocatedBlocks }.getOrElse(Map.empty) + } + + /** Get the blocks allocated to the given batch and stream. */ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + timeToAllocatedBlocks.get(batchTime).map { + _.getBlocksOfStream(streamId) + }.getOrElse(Seq.empty) + } + } + + /** Check if any blocks are left to be allocated to batches. */ + def hasUnallocatedReceivedBlocks: Boolean = synchronized { + !streamIdToUnallocatedBlockQueues.values.forall(_.isEmpty) + } + + /** + * Get blocks that have been added but not yet allocated to any batch. This method + * is primarily used for testing. + */ + def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized { + getReceivedBlockQueue(streamId).toSeq + } + + /** Clean up block information of old batches. */ + def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized { + assert(cleanupThreshTime.milliseconds < clock.currentTime()) + val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq + logInfo("Deleting batches " + timesToCleanup) + writeToLog(BatchCleanupEvent(timesToCleanup)) + timeToAllocatedBlocks --= timesToCleanup + logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds)) + log + } + + /** Stop the block tracker. */ + def stop() { + logManagerOption.foreach { _.stop() } + } + + /** + * Recover all the tracker actions from the write ahead logs to recover the state (unallocated + * and allocated block info) prior to failure. + */ + private def recoverFromWriteAheadLogs(): Unit = synchronized { + // Insert the recovered block information + def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) { + logTrace(s"Recovery: Inserting added block $receivedBlockInfo") + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + + // Insert the recovered block-to-batch allocations and clear the queue of received blocks + // (when the blocks were originally allocated to the batch, the queue must have been cleared). 
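The recovery path below replays the three tracker events from the write ahead log to rebuild the unallocated block queues and the batch allocations. A much simplified, single-stream sketch of that replay logic, using hypothetical event types rather than the real ones:

    sealed trait TrackerEvent
    case class BlockAdded(blockId: String) extends TrackerEvent
    case class BatchAllocated(batchTime: Long, blockIds: Seq[String]) extends TrackerEvent
    case class BatchesCleaned(batchTimes: Seq[Long]) extends TrackerEvent

    // Folding over the recovered events rebuilds the unallocated blocks and the
    // per-batch allocations; allocation empties the pending queue, cleanup drops batches.
    def replay(events: Seq[TrackerEvent]): (Seq[String], Map[Long, Seq[String]]) =
      events.foldLeft((Seq.empty[String], Map.empty[Long, Seq[String]])) {
        case ((unallocated, byBatch), BlockAdded(id)) =>
          (unallocated :+ id, byBatch)
        case ((unallocated, byBatch), BatchAllocated(time, ids)) =>
          (Seq.empty, byBatch + (time -> ids))
        case ((unallocated, byBatch), BatchesCleaned(times)) =>
          (unallocated, byBatch -- times)
      }
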
+ def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { + logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + + s"${allocatedBlocks.streamIdToAllocatedBlocks}") + streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + lastAllocatedBatchTime = batchTime + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + } + + // Cleanup the batch allocations + def cleanupBatches(batchTimes: Seq[Time]) { + logTrace(s"Recovery: Cleaning up batches $batchTimes") + timeToAllocatedBlocks --= batchTimes + } + + logManagerOption.foreach { logManager => + logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") + logManager.readFromLog().foreach { byteBuffer => + logTrace("Recovering record " + byteBuffer) + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + case BlockAdditionEvent(receivedBlockInfo) => + insertAddedBlock(receivedBlockInfo) + case BatchAllocationEvent(time, allocatedBlocks) => + insertAllocatedBatch(time, allocatedBlocks) + case BatchCleanupEvent(batchTimes) => + cleanupBatches(batchTimes) + } + } + } + } + + /** Write an update to the tracker to the write ahead log */ + private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + logDebug(s"Writing to log $record") + logManagerOption.foreach { logManager => + logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + } + } + + /** Get the queue of received blocks belonging to a particular stream */ + private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = { + streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue) + } +} + +private[streaming] object ReceivedBlockTracker { + def checkpointDirToLogDir(checkpointDir: String): String = { + new Path(checkpointDir, "receivedBlockMetadata").toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index d696563bcee83..1c3984d968d20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,15 +17,16 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} + +import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import akka.actor._ -import org.apache.spark.{SerializableWritable, Logging, SparkEnv, SparkException} + +import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} -import org.apache.spark.util.AkkaUtils /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -48,23 +49,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err * This class manages the execution of the receivers of NetworkInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() * has been called because it needs the final set of input streams at the time of instantiation. + * + * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing. 
*/ private[streaming] -class ReceiverTracker(ssc: StreamingContext) extends Logging { +class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging { - val receiverInputStreams = ssc.graph.getReceiverInputStreams() - val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverLauncher() - val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] - val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - val timeout = AkkaUtils.askTimeout(ssc.conf) - val listenerBus = ssc.scheduler.listenerBus + private val receiverInputStreams = ssc.graph.getReceiverInputStreams() + private val receiverInputStreamIds = receiverInputStreams.map { _.id } + private val receiverExecutor = new ReceiverLauncher() + private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] + private val receivedBlockTracker = new ReceivedBlockTracker( + ssc.sparkContext.conf, + ssc.sparkContext.hadoopConfiguration, + receiverInputStreamIds, + ssc.scheduler.clock, + Option(ssc.checkpointDir) + ) + private val listenerBus = ssc.scheduler.listenerBus // actor is created when generator starts. // This not being null means the tracker has been started and not stopped - var actor: ActorRef = null - var currentTime: Time = null + private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ def start() = synchronized { @@ -75,7 +81,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { if (!receiverInputStreams.isEmpty) { actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), "ReceiverTracker") - receiverExecutor.start() + if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } } @@ -84,45 +90,59 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { def stop() = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop() // Finally, stop the actor ssc.env.actorSystem.stop(actor) actor = null + receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } } - /** Return all the blocks received from a receiver. */ - def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = { - val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true) - logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks") - receivedBlockInfo.toArray + /** Allocate all unallocated blocks to the given batch. */ + def allocateBlocksToBatch(batchTime: Time): Unit = { + if (receiverInputStreams.nonEmpty) { + receivedBlockTracker.allocateBlocksToBatch(batchTime) + } + } + + /** Get the blocks for the given batch and all input streams. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = { + receivedBlockTracker.getBlocksOfBatch(batchTime) } - private def getReceivedBlockInfoQueue(streamId: Int) = { - receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo]) + /** Get the blocks allocated to the given batch and stream. 
*/ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) + } + } + + /** Clean up metadata older than the given threshold time */ + def cleanupOldMetadata(cleanupThreshTime: Time) { + receivedBlockTracker.cleanupOldBatches(cleanupThreshTime) } /** Register a receiver */ - def registerReceiver( + private def registerReceiver( streamId: Int, typ: String, host: String, receiverActor: ActorRef, sender: ActorRef ) { - if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + if (!receiverInputStreamIds.contains(streamId)) { + throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) } /** Deregister a receiver */ - def deregisterReceiver(streamId: Int, message: String, error: String) { + private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) @@ -131,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -141,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Add new blocks for the given stream */ - def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { - getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + - receivedBlockInfo.blockStoreResult.blockId) + private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { + receivedBlockTracker.addBlock(receivedBlockInfo) } /** Report error sent by a receiver */ - def reportError(streamId: Int, message: String, error: String) { + private def reportError(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(lastErrorMessage = message, lastError = error) @@ -157,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -167,8 +185,8 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Check if any blocks are left to be processed */ - def hasMoreReceivedBlockIds: Boolean = { - !receivedBlockInfo.values.forall(_.isEmpty) + def hasUnallocatedBlocks: Boolean 
= { + receivedBlockTracker.hasUnallocatedReceivedBlocks } /** Actor to receive messages from the receivers. */ @@ -178,8 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true case AddBlock(receivedBlockInfo) => - addBlocks(receivedBlockInfo) - sender ! true + sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) case DeregisterReceiver(streamId, message, error) => @@ -194,6 +211,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { @transient val thread = new Thread() { override def run() { try { + SparkEnv.set(env) startReceivers() } catch { case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") @@ -267,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") - ssc.sparkContext.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 6c8bb50145367..dbab685dc3511 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -17,18 +17,19 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.StreamingContext._ - -import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.SparkContext._ +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.language.existentials +import scala.reflect.ClassTag import util.ManualClock -import org.apache.spark.{SparkException, SparkConf} -import org.apache.spark.streaming.dstream.{WindowedDStream, DStream} -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import scala.reflect.ClassTag + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import scala.collection.mutable +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} class BasicOperationsSuite extends TestSuiteBase { test("map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ad1a6f01b3a57..9efe15d01ed0c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -73,7 +73,8 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr)) + new NioBlockTransferService(conf, securityMgr), securityMgr) + blockManager.initialize("app-id") tempDirectory = Files.createTempDir() manualClock.setTime(0) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala new file mode 100644 index 0000000000000..fd9c97f551c62 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} +import scala.util.Random + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.WriteAheadLogSuite._ +import org.apache.spark.util.Utils + +class ReceivedBlockTrackerSuite + extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") + conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + + val hadoopConf = new Configuration() + val akkaTimeout = 10 seconds + val streamId = 1 + + var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() + var checkpointDirectory: File = null + + before { + checkpointDirectory = Files.createTempDir() + } + + after { + allReceivedBlockTrackers.foreach { _.stop() } + if (checkpointDirectory != null && checkpointDirectory.exists()) { + FileUtils.deleteDirectory(checkpointDirectory) + checkpointDirectory = null + } + } + + test("block addition, and block to batch allocation") { + val receivedBlockTracker = createTracker(enableCheckpoint = false) + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + val blockInfos = generateBlockInfos() + blockInfos.map(receivedBlockTracker.addBlock) + + // Verify added blocks are unallocated blocks + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Allocate the blocks to a batch and verify that all of them have been allocated + receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty + + // Allocate no blocks to another batch + receivedBlockTracker.allocateBlocksToBatch(2) + 
receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + + // Verify that batch 2 cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(2) + } + + // Verify that older batches cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(1) + } + } + + test("block addition, block to batch allocation and cleanup with write ahead log") { + val manualClock = new ManualClock + conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1) + + // Set the time increment level to twice the rotation interval so that every increment creates + // a new log file + val timeIncrementMillis = 2000L + def incrementTime() { + manualClock.addToTime(timeIncrementMillis) + } + + // Generate and add blocks to the given tracker + def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = { + val blockInfos = generateBlockInfos() + blockInfos.map(tracker.addBlock) + blockInfos + } + + // Print the data present in the log ahead files in the log directory + def printLogFiles(message: String) { + val fileContents = getWriteAheadLogFiles().map { file => + (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}") + }.mkString("\n") + logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") + } + + // Start tracker and add blocks + val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock) + val blockInfos1 = addBlockInfos(tracker1) + tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Verify whether write ahead log has correct contents + val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent) + getWrittenLogData() shouldEqual expectedWrittenData1 + getWriteAheadLogFiles() should have size 1 + + // Restart tracker and verify recovered list of unallocated blocks + incrementTime() + val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Allocate blocks to batch and verify whether the unallocated blocks got allocated + val batchTime1 = manualClock.currentTime + tracker2.allocateBlocksToBatch(batchTime1) + tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + + // Add more blocks and allocate to another batch + incrementTime() + val batchTime2 = manualClock.currentTime + val blockInfos2 = addBlockInfos(tracker2) + tracker2.allocateBlocksToBatch(batchTime2) + tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + + // Verify whether log has correct contents + val expectedWrittenData2 = expectedWrittenData1 ++ + Seq(createBatchAllocation(batchTime1, blockInfos1)) ++ + blockInfos2.map(BlockAdditionEvent) ++ + Seq(createBatchAllocation(batchTime2, blockInfos2)) + getWrittenLogData() shouldEqual expectedWrittenData2 + + // Restart tracker and verify recovered state + incrementTime() + val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + tracker3.getUnallocatedBlocks(streamId) shouldBe empty + + // Cleanup first batch but not second batch + val oldestLogFile = getWriteAheadLogFiles().head + incrementTime() + tracker3.cleanupOldBatches(batchTime2) + + // Verify that the batch allocations have been cleaned, and the act has been written 
to log + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty + getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1)) + + // Verify that at least one log file gets deleted + eventually(timeout(10 seconds), interval(10 millisecond)) { + getWriteAheadLogFiles() should not contain oldestLogFile + } + printLogFiles("After cleanup") + + // Restart tracker and verify recovered state, specifically whether info about the first + // batch has been removed, but not the second batch + incrementTime() + val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker4.getUnallocatedBlocks(streamId) shouldBe empty + tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned + tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + } + + /** + * Create tracker object with the optional provided clock. Use fake clock if you + * want to control time by manually incrementing it to test log cleanup. + */ + def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = { + val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None + val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) + allReceivedBlockTrackers += tracker + tracker + } + + /** Generate blocks infos using random ids */ + def generateBlockInfos(): Seq[ReceivedBlockInfo] = { + List.fill(5)(ReceivedBlockInfo(streamId, 0, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + } + + /** Get all the data written in the given write ahead log file. */ + def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { + getWrittenLogData(Seq(logFile)) + } + + /** + * Get all the data written in the given write ahead log files. By default, it will read all + * files in the test log directory. 
+ */ + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + logFiles.flatMap { + file => new WriteAheadLogReader(file, hadoopConf).toSeq + }.map { byteBuffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.toList + } + + /** Get all the write ahead log files in the test directory */ + def getWriteAheadLogFiles(): Seq[String] = { + import ReceivedBlockTracker._ + val logDir = checkpointDirToLogDir(checkpointDirectory.toString) + getLogFilesInDirectory(logDir).map { _.toString } + } + + /** Create batch allocation object from the given info */ + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) + } + + /** Create batch cleanup object from the given info */ + def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { + BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) + } + + implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds) + + implicit def timeToMillis(time: Time): Long = time.milliseconds +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 10160244bcc91..d2b983c4b4d1a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -117,12 +117,12 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll { ) // Create the RDD and verify whether the returned data is correct - val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) assert(rdd.collect() === data.flatten) if (testStoreInBM) { - val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) assert(rdd2.collect() === data.flatten) assert( diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7ee4b5c842df1..7023a1170654f 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.util.JavaUtils @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -90,6 +91,22 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = 
securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager val startReq = Records.newRecord(classOf[StartContainerRequest]) .asInstanceOf[StartContainerRequest] diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 0b5a92d87d722..fdd3c2300fa78 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -89,6 +90,22 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager nmClient.startContainer(container, ctx) }
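Both the alpha and stable YARN hunks key off spark.shuffle.service.enabled before registering the "spark_shuffle" service data. A short usage sketch of the application-side configuration; spark.authenticate is the standard Spark authentication switch and is assumed here rather than taken from this patch:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      // Makes ExecutorRunnable register with the NodeManager's external shuffle service.
      .set("spark.shuffle.service.enabled", "true")
      // With authentication on, SecurityManager.getSecretKey() returns a real secret,
      // which is what gets passed to the shuffle service as the "spark_shuffle" data.
      .set("spark.authenticate", "true")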