SPARK-2269 Refactor mesos scheduler resourceOffers and add unit test
Author: Timothy Chen <tnachen@gmail.com>

Closes #1487 from tnachen/resource_offer_refactor and squashes the following commits:

4ea5dec [Timothy Chen] Rebase from master and address comments
9ccab09 [Timothy Chen] Address review comments
e6494dc [Timothy Chen] Refactor class loading
8207428 [Timothy Chen] Refactor mesos scheduler resourceOffers and add unit test
tnachen authored and Andrew Or committed Nov 11, 2014
1 parent 7f37188 commit a878660
Showing 2 changed files with 152 additions and 79 deletions.
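The core of the refactor is visible in the first file below: every Mesos callback used to bracket its body with a setClassLoader()/restoreClassLoader() pair and a hand-written try/finally, and the patch replaces that with a single inClassLoader() helper that takes the callback body as a by-name parameter. As a minimal standalone sketch of that loan pattern (the demo object, appClassLoader field, and println body are illustrative, not part of the patch):

// A minimal, self-contained sketch of the loan pattern the patch introduces:
// swap the context ClassLoader in, run the caller's block, and restore the
// old loader even if the block throws. `appClassLoader` is a hypothetical
// stand-in for the backend's `classLoader` field.
object ContextClassLoaderDemo {
  val appClassLoader: ClassLoader = getClass.getClassLoader

  private def inClassLoader()(fun: => Unit): Unit = {
    val oldClassLoader = Thread.currentThread.getContextClassLoader
    Thread.currentThread.setContextClassLoader(appClassLoader)
    try {
      fun // the by-name block runs with the swapped-in loader
    } finally {
      Thread.currentThread.setContextClassLoader(oldClassLoader)
    }
  }

  def main(args: Array[String]): Unit = {
    inClassLoader() {
      println("context loader: " + Thread.currentThread.getContextClassLoader)
    }
  }
}

The empty first parameter list mirrors the patch and lets call sites read like a control structure: inClassLoader() { ... }.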
core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -166,29 +166,16 @@ private[spark] class MesosSchedulerBackend(
     execArgs
   }
 
-  private def setClassLoader(): ClassLoader = {
-    val oldClassLoader = Thread.currentThread.getContextClassLoader
-    Thread.currentThread.setContextClassLoader(classLoader)
-    oldClassLoader
-  }
-
-  private def restoreClassLoader(oldClassLoader: ClassLoader) {
-    Thread.currentThread.setContextClassLoader(oldClassLoader)
-  }
-
   override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
 
   override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
-    val oldClassLoader = setClassLoader()
-    try {
+    inClassLoader() {
       appId = frameworkId.getValue
       logInfo("Registered as framework ID " + appId)
       registeredLock.synchronized {
         isRegistered = true
         registeredLock.notifyAll()
       }
-    } finally {
-      restoreClassLoader(oldClassLoader)
     }
   }
 
@@ -200,6 +187,16 @@ private[spark] class MesosSchedulerBackend(
     }
   }
 
+  private def inClassLoader()(fun: => Unit) = {
+    val oldClassLoader = Thread.currentThread.getContextClassLoader
+    Thread.currentThread.setContextClassLoader(classLoader)
+    try {
+      fun
+    } finally {
+      Thread.currentThread.setContextClassLoader(oldClassLoader)
+    }
+  }
+
   override def disconnected(d: SchedulerDriver) {}
 
   override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
@@ -210,66 +207,57 @@ private[spark] class MesosSchedulerBackend(
    * tasks are balanced across the cluster.
    */
   override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
-    val oldClassLoader = setClassLoader()
-    try {
-      synchronized {
-        // Build a big list of the offerable workers, and remember their indices so that we can
-        // figure out which Offer to reply to for each worker
-        val offerableWorkers = new ArrayBuffer[WorkerOffer]
-        val offerableIndices = new HashMap[String, Int]
-
-        def sufficientOffer(o: Offer) = {
-          val mem = getResource(o.getResourcesList, "mem")
-          val cpus = getResource(o.getResourcesList, "cpus")
-          val slaveId = o.getSlaveId.getValue
-          (mem >= MemoryUtils.calculateTotalMemory(sc) &&
-            // need at least 1 for executor, 1 for task
-            cpus >= 2 * scheduler.CPUS_PER_TASK) ||
-            (slaveIdsWithExecutors.contains(slaveId) &&
-              cpus >= scheduler.CPUS_PER_TASK)
-        }
+    inClassLoader() {
+      val (acceptedOffers, declinedOffers) = offers.partition { o =>
+        val mem = getResource(o.getResourcesList, "mem")
+        val cpus = getResource(o.getResourcesList, "cpus")
+        val slaveId = o.getSlaveId.getValue
+        (mem >= MemoryUtils.calculateTotalMemory(sc) &&
+          // need at least 1 for executor, 1 for task
+          cpus >= 2 * scheduler.CPUS_PER_TASK) ||
+          (slaveIdsWithExecutors.contains(slaveId) &&
+            cpus >= scheduler.CPUS_PER_TASK)
+      }
 
-        for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) {
-          val slaveId = offer.getSlaveId.getValue
-          offerableIndices.put(slaveId, index)
-          val cpus = if (slaveIdsWithExecutors.contains(slaveId)) {
-            getResource(offer.getResourcesList, "cpus").toInt
-          } else {
-            // If the executor doesn't exist yet, subtract CPU for executor
-            getResource(offer.getResourcesList, "cpus").toInt -
-              scheduler.CPUS_PER_TASK
-          }
-          offerableWorkers += new WorkerOffer(
-            offer.getSlaveId.getValue,
-            offer.getHostname,
-            cpus)
-        }
+      val offerableWorkers = acceptedOffers.map { o =>
+        val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
+          getResource(o.getResourcesList, "cpus").toInt
+        } else {
+          // If the executor doesn't exist yet, subtract CPU for executor
+          getResource(o.getResourcesList, "cpus").toInt -
+            scheduler.CPUS_PER_TASK
+        }
+        new WorkerOffer(
+          o.getSlaveId.getValue,
+          o.getHostname,
+          cpus)
+      }
 
-        // Call into the TaskSchedulerImpl
-        val taskLists = scheduler.resourceOffers(offerableWorkers)
-
-        // Build a list of Mesos tasks for each slave
-        val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]())
-        for ((taskList, index) <- taskLists.zipWithIndex) {
-          if (!taskList.isEmpty) {
-            for (taskDesc <- taskList) {
-              val slaveId = taskDesc.executorId
-              val offerNum = offerableIndices(slaveId)
-              slaveIdsWithExecutors += slaveId
-              taskIdToSlaveId(taskDesc.taskId) = slaveId
-              mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId))
-            }
-          }
-        }
+      val slaveIdToOffer = acceptedOffers.map(o => o.getSlaveId.getValue -> o).toMap
+
+      val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
+
+      // Call into the TaskSchedulerImpl
+      scheduler.resourceOffers(offerableWorkers)
+        .filter(!_.isEmpty)
+        .foreach { offer =>
+          offer.foreach { taskDesc =>
+            val slaveId = taskDesc.executorId
+            slaveIdsWithExecutors += slaveId
+            taskIdToSlaveId(taskDesc.taskId) = slaveId
+            mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
+              .add(createMesosTask(taskDesc, slaveId))
          }
        }
 
-        // Reply to the offers
-        val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
-        for (i <- 0 until offers.size) {
-          d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters)
-        }
-      }
-    } finally {
-      restoreClassLoader(oldClassLoader)
+      // Reply to the offers
+      val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
+
+      mesosTasks.foreach { case (slaveId, tasks) =>
+        d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
+      }
+
+      declinedOffers.foreach(o => d.declineOffer(o.getId))
     }
   }
 
@@ -308,8 +296,7 @@ private[spark] class MesosSchedulerBackend(
   }
 
   override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
-    val oldClassLoader = setClassLoader()
-    try {
+    inClassLoader() {
       val tid = status.getTaskId.getValue.toLong
       val state = TaskState.fromMesos(status.getState)
       synchronized {
@@ -322,18 +309,13 @@ private[spark] class MesosSchedulerBackend(
         }
       }
       scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer)
-    } finally {
-      restoreClassLoader(oldClassLoader)
     }
   }
 
   override def error(d: SchedulerDriver, message: String) {
-    val oldClassLoader = setClassLoader()
-    try {
+    inClassLoader() {
       logError("Mesos error: " + message)
       scheduler.error(message)
-    } finally {
-      restoreClassLoader(oldClassLoader)
     }
   }
 
@@ -350,15 +332,12 @@ private[spark] class MesosSchedulerBackend(
   override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
 
   private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
-    val oldClassLoader = setClassLoader()
-    try {
+    inClassLoader() {
       logInfo("Mesos slave lost: " + slaveId.getValue)
       synchronized {
         slaveIdsWithExecutors -= slaveId.getValue
       }
       scheduler.executorLost(slaveId.getValue, reason)
-    } finally {
-      restoreClassLoader(oldClassLoader)
     }
   }
 
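The acceptance rule driving the partition call above: an offer is accepted either if it can host a brand-new executor (enough memory per MemoryUtils.calculateTotalMemory and at least 2 * CPUS_PER_TASK cores, one task's worth for the executor itself and one for a task), or if its slave already runs an executor and the offer has at least CPUS_PER_TASK cores. A simplified, self-contained sketch of that split (the Offer case class and the constants are illustrative stand-ins, not Spark or Mesos API):

// Simplified model of the accept/decline split in the new resourceOffers.
// `Offer`, `requiredExecutorMem`, and `cpusPerTask` stand in for the Mesos
// protobufs and scheduler fields used by the real code.
case class Offer(slaveId: String, mem: Double, cpus: Double)

object OfferPartitionDemo {
  val requiredExecutorMem = 512.0        // MemoryUtils.calculateTotalMemory(sc) in the patch
  val cpusPerTask = 2                    // scheduler.CPUS_PER_TASK in the patch
  val slaveIdsWithExecutors = Set("s2")  // slaves that already run an executor

  def main(args: Array[String]): Unit = {
    val offers = Seq(Offer("s1", 1024, 4), Offer("s2", 128, 2), Offer("s3", 128, 1))
    val (accepted, declined) = offers.partition { o =>
      (o.mem >= requiredExecutorMem && o.cpus >= 2 * cpusPerTask) ||
        (slaveIdsWithExecutors.contains(o.slaveId) && o.cpus >= cpusPerTask)
    }
    // s1 can host a new executor, s2 already has one, s3 is too small:
    // prints "accepted: List(s1, s2), declined: List(s3)"
    println("accepted: " + accepted.map(_.slaveId) + ", declined: " + declined.map(_.slaveId))
  }
}

The unit test in the second file exercises exactly this boundary: one offer with sufficient memory and one just below the executor-memory threshold.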
core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.scheduler.mesos

import org.scalatest.FunSuite
import org.apache.spark.{scheduler, SparkConf, SparkContext, LocalSparkContext}
import org.apache.spark.scheduler.{TaskDescription, WorkerOffer, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend}
import org.apache.mesos.SchedulerDriver
import org.apache.mesos.Protos._
import org.scalatest.mock.EasyMockSugar
import org.apache.mesos.Protos.Value.Scalar
import org.easymock.{Capture, EasyMock}
import java.nio.ByteBuffer
import java.util.Collections
import java.util
import scala.collection.mutable

class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar {
test("mesos resource offer is launching tasks") {
def createOffer(id: Int, mem: Int, cpu: Int) = {
val builder = Offer.newBuilder()
builder.addResourcesBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
.setScalar(Scalar.newBuilder().setValue(mem))
builder.addResourcesBuilder()
.setName("cpus")
.setType(Value.Type.SCALAR)
.setScalar(Scalar.newBuilder().setValue(cpu))
builder.setId(OfferID.newBuilder().setValue(id.toString).build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1"))
.setSlaveId(SlaveID.newBuilder().setValue("s1")).setHostname("localhost").build()
}

val driver = EasyMock.createMock(classOf[SchedulerDriver])
val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl])

val sc = EasyMock.createMock(classOf[SparkContext])

EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes()
EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes()
EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes()
EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes()
EasyMock.replay(sc)
val minMem = MemoryUtils.calculateTotalMemory(sc).toInt
val minCpu = 4
val offers = new java.util.ArrayList[Offer]
offers.add(createOffer(1, minMem, minCpu))
offers.add(createOffer(1, minMem - 1, minCpu))
val backend = new MesosSchedulerBackend(taskScheduler, sc, "master")
val workerOffers = Seq(offers.get(0)).map(o => new WorkerOffer(
o.getSlaveId.getValue,
o.getHostname,
2
))
val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0)))
EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(workerOffers))).andReturn(Seq(Seq(taskDesc)))
EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes()
EasyMock.replay(taskScheduler)
val capture = new Capture[util.Collection[TaskInfo]]
EasyMock.expect(
driver.launchTasks(
EasyMock.eq(Collections.singleton(offers.get(0).getId)),
EasyMock.capture(capture),
EasyMock.anyObject(classOf[Filters])
)
).andReturn(Status.valueOf(1))
EasyMock.expect(driver.declineOffer(offers.get(1).getId)).andReturn(Status.valueOf(1))
EasyMock.replay(driver)
backend.resourceOffers(driver, offers)
assert(capture.getValue.size() == 1)
val taskInfo = capture.getValue.iterator().next()
assert(taskInfo.getName.equals("n1"))
val cpus = taskInfo.getResourcesList.get(0)
assert(cpus.getName.equals("cpus"))
assert(cpus.getScalar.getValue.equals(2.0))
assert(taskInfo.getSlaveId.getValue.equals("s1"))
}
}
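The assertion style worth noting in this suite is EasyMock's Capture: rather than asserting inside the expectation, the test captures the task collection handed to driver.launchTasks and inspects it after the call. A stripped-down sketch of the same idiom (the Greeter trait is a made-up example, not part of the patch):

// Stripped-down EasyMock Capture idiom, as used by the suite above.
// `Greeter` is a hypothetical interface, not from the patch.
import org.easymock.{Capture, EasyMock}

trait Greeter {
  def greet(name: String): String
}

object CaptureDemo {
  def main(args: Array[String]): Unit = {
    val greeter = EasyMock.createMock(classOf[Greeter])
    val captured = new Capture[String]
    // Record phase: expect one call, capturing whatever argument arrives.
    EasyMock.expect(greeter.greet(EasyMock.capture(captured))).andReturn("hi")
    EasyMock.replay(greeter)

    greeter.greet("mesos")               // exercise the mock

    EasyMock.verify(greeter)             // confirm the expectation was met
    assert(captured.getValue == "mesos") // inspect the captured argument
  }
}

Capturing keeps the expectation loose (any argument matches) while still letting the test make precise assertions afterwards, which is why the suite checks the captured TaskInfo's name, CPU resource, and slave ID only after resourceOffers returns.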
