SPARK-1244: Throw exception if map output status exceeds frame size #152

Closed · wants to merge 6 commits
19 changes: 17 additions & 2 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -35,13 +35,28 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage

-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
   extends Actor with Logging {
+  val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+
   def receive = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = sender.path.address.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
-      sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
+      val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
+      val serializedSize = mapOutputStatuses.size
+      if (serializedSize > maxAkkaFrameSize) {
+        val msg = s"Map output statuses were $serializedSize bytes which " +
+          s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
+
+        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
+         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
+         * will ultimately remove this entire code path. */
+        val exception = new SparkException(msg)
+        logError(msg, exception)
+        throw exception
+      }
+      sender ! mapOutputStatuses

     case StopMapOutputTracker =>
       logInfo("MapOutputTrackerActor stopped!")
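For context, a standalone sketch of the guard this hunk adds. The message format and SparkException come from the diff above; the sizes are made up for illustration:

    import org.apache.spark.SparkException

    // Illustrative values only: frameSize is configured in MB and compared in bytes.
    val maxAkkaFrameSize = 10 * 1024 * 1024   // 10 MB, the spark.akka.frameSize default
    val serializedSize = 12 * 1024 * 1024     // pretend the statuses serialized to 12 MB

    if (serializedSize > maxAkkaFrameSize) {
      val msg = s"Map output statuses were $serializedSize bytes which " +
        s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
      throw new SparkException(msg)  // per the comment above, the actor simply restarts
    }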
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -186,7 +186,7 @@ object SparkEnv extends Logging {
     }
     mapOutputTracker.trackerActor = registerOrLookup(
       "MapOutputTracker",
-      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
+      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

     val shuffleFetcher = instantiateClass[ShuffleFetcher](
       "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
6 changes: 2 additions & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.scheduler._
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}

 /**
  * Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -120,9 +120,7 @@ private[spark] class Executor(

   // Akka's message frame size. If task result is bigger than this, we use the block manager
   // to send the result back.
-  private val akkaFrameSize = {
-    env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
-  }
+  private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)

   // Start worker thread pool
   val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
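The comment retained above describes the rule this value feeds: task results larger than one Akka frame travel through the block manager rather than over Akka. A minimal self-contained sketch of that routing decision; the function and labels are hypothetical stand-ins, not Executor's actual API:

    import java.nio.ByteBuffer

    // Hypothetical stand-in for the decision the comment describes: results that
    // fit in one Akka frame are sent directly; larger ones are stored in the
    // block manager and only a small reference is sent back.
    def resultRoute(serializedResult: ByteBuffer, akkaFrameSize: Int): String =
      if (serializedResult.limit() >= akkaFrameSize) "block manager" else "akka"

    val akkaFrameSize = 10 * 1024 * 1024  // 10 MB default, now via AkkaUtils.maxFrameSizeBytes
    assert(resultRoute(ByteBuffer.allocate(1024), akkaFrameSize) == "akka")
    assert(resultRoute(ByteBuffer.allocate(12 * 1024 * 1024), akkaFrameSize) == "block manager")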
9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -49,7 +49,7 @@ private[spark] object AkkaUtils extends Logging {

     val akkaTimeout = conf.getInt("spark.akka.timeout", 100)

-    val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10)
+    val akkaFrameSize = maxFrameSizeBytes(conf)
     val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
     val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
     if (!akkaLogLifecycleEvents) {
@@ -92,7 +92,7 @@
       |akka.remote.netty.tcp.port = $port
       |akka.remote.netty.tcp.tcp-nodelay = on
       |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
-      |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
+      |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B
       |akka.remote.netty.tcp.execution-pool-size = $akkaThreads
       |akka.actor.default-dispatcher.throughput = $akkaBatchSize
       |akka.log-config-on-start = $logAkkaConfig
@@ -121,4 +121,9 @@
   def lookupTimeout(conf: SparkConf): FiniteDuration = {
     Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds")
   }
+
+  /** Returns the configured max frame size for Akka messages in bytes. */
+  def maxFrameSizeBytes(conf: SparkConf): Int = {
+    conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
+  }
 }
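A quick usage sketch of the new helper: the setting is read in MB and converted to bytes, which is also why the generated config line above now appends B to a byte count instead of MiB to a megabyte count. Values here are illustrative:

    import org.apache.spark.SparkConf
    import org.apache.spark.util.AkkaUtils

    // AkkaUtils is private[spark], so this only compiles from inside the
    // org.apache.spark package, as the test suites below do.
    val conf = new SparkConf().set("spark.akka.frameSize", "10") // MB, the default
    assert(AkkaUtils.maxFrameSizeBytes(conf) == 10 * 1024 * 1024) // 10485760 bytes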
10 changes: 5 additions & 5 deletions core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -45,12 +45,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val badconf = new SparkConf
     badconf.set("spark.authenticate", "true")
     badconf.set("spark.authenticate.secret", "bad")
-    val securityManagerBad = new SecurityManager(badconf);
+    val securityManagerBad = new SecurityManager(badconf)

     assert(securityManagerBad.isAuthenticationEnabled() === true)

@@ -84,7 +84,7 @@

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
@@ -136,7 +136,7 @@

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val goodconf = new SparkConf
     goodconf.set("spark.authenticate", "true")
@@ -189,7 +189,7 @@

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
58 changes: 52 additions & 6 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.concurrent.Await

 import akka.actor._
+import akka.testkit.TestActorRef
 import org.scalatest.FunSuite

 import org.apache.spark.scheduler.MapStatus
@@ -51,14 +52,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master start and stop") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.stop()
   }

   test("master register and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,7 +80,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master register and unregister and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -100,11 +104,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     val hostname = "localhost"
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
-    System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+
+    // Will be cleared by LocalSparkContext
+    System.setProperty("spark.driver.port", boundPort.toString)

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
@@ -126,7 +132,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-      Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+           Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))

     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
     masterTracker.incrementEpoch()
@@ -136,4 +142,44 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     // failure should be cached
     intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
   }
+
+  test("remote fetch below akka frame size") {
+    val newConf = new SparkConf
+    newConf.set("spark.akka.frameSize", "1")
+    newConf.set("spark.akka.askTimeout", "1") // Fail fast
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    val actorSystem = ActorSystem("test")
+    val actorRef = TestActorRef[MapOutputTrackerMasterActor](
+      new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+    val masterActor = actorRef.underlyingActor
+
+    // Frame size should be ~123B, and no exception should be thrown
+    masterTracker.registerShuffle(10, 1)
+    masterTracker.registerMapOutput(10, 0, new MapStatus(
+      BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0)))
+    masterActor.receive(GetMapOutputStatuses(10))
+  }
+
+  test("remote fetch exceeds akka frame size") {
+    val newConf = new SparkConf
+    newConf.set("spark.akka.frameSize", "1")
+    newConf.set("spark.akka.askTimeout", "1") // Fail fast
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    val actorSystem = ActorSystem("test")
+    val actorRef = TestActorRef[MapOutputTrackerMasterActor](
+      new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+    val masterActor = actorRef.underlyingActor
+
+    // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception.
+    // Note that the size is hand-selected here because map output statuses are compressed before
+    // being sent.
+    masterTracker.registerShuffle(20, 100)
+    (0 until 100).foreach { i =>
+      masterTracker.registerMapOutput(20, i, new MapStatus(
+        BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0)))
+    }
+    intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
+  }
 }
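For completeness, the user-facing remedy the new error message points toward is raising spark.akka.frameSize (in MB) above the serialized size of the map output statuses. A hedged example; the app name and the value 64 are purely illustrative:

    import org.apache.spark.SparkConf

    // 64 MB is illustrative; pick a value larger than the reported statuses size.
    val conf = new SparkConf()
      .setAppName("shuffle-heavy-job") // hypothetical application
      .set("spark.akka.frameSize", "64")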
}