Skip to content

Commit

Permalink
SPARK-7436: Fixed instantiation of custom recovery mode factory and a…
Browse files Browse the repository at this point in the history
…dded tests
  • Loading branch information
jacek-lewandowski committed May 7, 2015
1 parent 4f87e95 commit ff0a3c2
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ private[master] class Master(
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
val factory = clazz.getConstructor(conf.getClass, Serialization.getClass)
val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
.newInstance(conf, SerializationExtension(context.system))
.asInstanceOf[StandaloneRecoveryModeFactory]
(factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// This file is placed in different package to make sure all of these components work well
// when they are outside of org.apache.spark.
package other.supplier

import scala.collection.mutable
import scala.reflect.ClassTag

import akka.serialization.Serialization

import org.apache.spark.SparkConf
import org.apache.spark.deploy.master._

class CustomRecoveryModeFactory(
conf: SparkConf,
serialization: Serialization
) extends StandaloneRecoveryModeFactory(conf, serialization) {

CustomRecoveryModeFactory.instantiationAttempts += 1

/**
* PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
* is handled for recovery.
*
*/
override def createPersistenceEngine(): PersistenceEngine =
new CustomPersistenceEngine(serialization)

/**
* Create an instance of LeaderAgent that decides who gets elected as master.
*/
override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent =
new CustomLeaderElectionAgent(master)
}

object CustomRecoveryModeFactory {
@volatile var instantiationAttempts = 0
}

class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
val data = mutable.HashMap[String, Array[Byte]]()

CustomPersistenceEngine.lastInstance = Some(this)

/**
* Defines how the object is serialized and persisted. Implementation will
* depend on the store used.
*/
override def persist(name: String, obj: Object): Unit = {
CustomPersistenceEngine.persistAttempts += 1
serialization.serialize(obj) match {
case util.Success(bytes) => data += name -> bytes
case util.Failure(cause) => throw new RuntimeException(cause)
}
}

/**
* Defines how the object referred by its name is removed from the store.
*/
override def unpersist(name: String): Unit = {
CustomPersistenceEngine.unpersistAttempts += 1
data -= name
}

/**
* Gives all objects, matching a prefix. This defines how objects are
* read/deserialized back.
*/
override def read[T: ClassTag](prefix: String): Seq[T] = {
CustomPersistenceEngine.readAttempts += 1
val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
val results = for ((name, bytes) <- data; if name.startsWith(prefix))
yield serialization.deserialize(bytes, clazz)

results.find(_.isFailure).foreach {
case util.Failure(cause) => throw new RuntimeException(cause)
}

results.flatMap(_.toOption).toSeq
}
}

object CustomPersistenceEngine {
@volatile var persistAttempts = 0
@volatile var unpersistAttempts = 0
@volatile var readAttempts = 0

@volatile var lastInstance: Option[CustomPersistenceEngine] = None
}

class CustomLeaderElectionAgent(val masterActor: LeaderElectable) extends LeaderElectionAgent {
masterActor.electedLeader()
}

100 changes: 97 additions & 3 deletions core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@

package org.apache.spark.deploy.master

import java.util.Date

import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps

import akka.actor.Address
import org.scalatest.FunSuite
import org.scalatest.{FunSuite, Matchers}
import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}

import org.apache.spark.{SSLOptions, SparkConf, SparkException}
import org.apache.spark.deploy._
import org.apache.spark.{SparkConf, SparkException}

class MasterSuite extends FunSuite {
class MasterSuite extends FunSuite with Matchers {

test("toAkkaUrl") {
val conf = new SparkConf(loadDefaults = false)
Expand Down Expand Up @@ -63,4 +71,90 @@ class MasterSuite extends FunSuite {
}
assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
}

test("can use a custom recovery mode factory") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.deploy.recoveryMode", "CUSTOM")
conf.set("spark.deploy.recoveryMode.factory",
classOf[CustomRecoveryModeFactory].getCanonicalName)

val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts

val commandToPersist = new Command(
mainClass = "",
arguments = Nil,
environment = Map.empty,
classPathEntries = Nil,
libraryPathEntries = Nil,
javaOpts = Nil
)

val appToPersist = new ApplicationInfo(
startTime = 0,
id = "test_app",
desc = new ApplicationDescription(
name = "",
maxCores = None,
memoryPerExecutorMB = 0,
command = commandToPersist,
appUiUrl = "",
eventLogDir = None,
eventLogCodec = None,
coresPerExecutor = None),
submitDate = new Date(),
driver = null,
defaultCores = 0
)

val driverToPersist = new DriverInfo(
startTime = 0,
id = "test_driver",
desc = new DriverDescription(
jarUrl = "",
mem = 0,
cores = 0,
supervise = false,
command = commandToPersist
),
submitDate = new Date()
)

val workerToPersist = new WorkerInfo(
id = "test_worker",
host = "127.0.0.1",
port = 10000,
cores = 0,
memory = 0,
actor = null,
webUiPort = 0,
publicAddress = ""
)

val (actorSystem, port, uiPort, restPort) =
Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf)

try {
Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds)

CustomPersistenceEngine.lastInstance.isDefined shouldBe true
val persistenceEngine = CustomPersistenceEngine.lastInstance.get

persistenceEngine.addApplication(appToPersist)
persistenceEngine.addDriver(driverToPersist)
persistenceEngine.addWorker(workerToPersist)

val (apps, drivers, workers) = persistenceEngine.readPersistedData()

apps.map(_.id) should contain(appToPersist.id)
drivers.map(_.id) should contain(driverToPersist.id)
workers.map(_.id) should contain(workerToPersist.id)

} finally {
actorSystem.shutdown()
actorSystem.awaitTermination()
}

CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts
}

}

0 comments on commit ff0a3c2

Please sign in to comment.