Skip to content

Commit

Permalink
Merge branch 'master' into profiler
Browse files Browse the repository at this point in the history
Conflicts:
	docs/configuration.md
  • Loading branch information
davies committed Sep 14, 2014
2 parents c23865c + 2aea0da commit 09d02c3
Show file tree
Hide file tree
Showing 44 changed files with 924 additions and 496 deletions.
8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ class SparkEnv (
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}

private[spark]
def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.get(key).foreach(_.releaseWorker(worker))
}
}
}

object SparkEnv extends Logging {
Expand Down
18 changes: 15 additions & 3 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.TaskCompletionListener
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}


/**
Expand All @@ -41,7 +41,7 @@ class TaskContext(
val attemptId: Long,
val runningLocally: Boolean = false,
private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends Serializable {
extends Serializable with Logging {

@deprecated("use partitionId", "0.8.1")
def splitId = partitionId
Expand Down Expand Up @@ -103,8 +103,20 @@ class TaskContext(
/** Marks the task as completed and triggers the listeners. */
private[spark] def markTaskCompleted(): Unit = {
completed = true
val errorMsgs = new ArrayBuffer[String](2)
// Process complete callbacks in the reverse order of registration
onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
onCompleteCallbacks.reverse.foreach { listener =>
try {
listener.onTaskCompletion(this)
} catch {
case e: Throwable =>
errorMsgs += e.getMessage
logError("Error in TaskCompletionListener", e)
}
}
if (errorMsgs.nonEmpty) {
throw new TaskCompletionListenerException(errorMsgs)
}
}

/** Marks the task for interruption, i.e. cancellation. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.api.java

import java.io.Closeable
import java.util
import java.util.{Map => JMap}

Expand All @@ -40,7 +41,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
* [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
*/
class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
class JavaSparkContext(val sc: SparkContext)
extends JavaSparkContextVarargsWorkaround with Closeable {

/**
* Create a JavaSparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
Expand Down Expand Up @@ -534,6 +537,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
sc.stop()
}

override def close(): Unit = stop()

/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
Expand Down
58 changes: 46 additions & 12 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}
Expand Down Expand Up @@ -52,6 +53,7 @@ private[spark] class PythonRDD(
extends RDD[Array[Byte]](parent) {

val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

override def getPartitions = parent.partitions

Expand All @@ -63,19 +65,26 @@ private[spark] class PythonRDD(
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
if (reuse_worker) {
envVars += ("SPARK_REUSE_WORKER" -> "1")
}
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)

// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)

var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()

// Cleanup the worker socket. This will also cause the Python worker to exit.
try {
worker.close()
} catch {
case e: Exception => logWarning("Failed to close worker socket", e)
if (reuse_worker && complete_cleanly) {
env.releasePythonWorker(pythonExec, envVars.toMap, worker)
} else {
try {
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}

Expand Down Expand Up @@ -133,6 +142,7 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
complete_cleanly = true
null
}
} catch {
Expand Down Expand Up @@ -195,29 +205,45 @@ private[spark] class PythonRDD(
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
dataOut.writeInt(cnt)
for (bid <- oldBids) {
if (!newBids.contains(bid)) {
// remove the broadcast from worker
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
}
for (broadcast <- broadcastVars) {
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
oldBids.add(broadcast.id)
}
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
worker.shutdownOutput()

case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
} finally {
Try(worker.shutdownOutput()) // kill Python worker process
worker.shutdownOutput()
}
}
}
Expand Down Expand Up @@ -278,6 +304,14 @@ private object SpecialLengths {
private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")

// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
private def getWorkerBroadcasts(worker: Socket) = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
}
}

/**
* Adapter for calling SparkContext#runJob from Python.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
val idleWorkers = new mutable.Queue[Socket]()
var lastActivity = 0L
new MonitorThread().start()

var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()

Expand All @@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

def create(): Socket = {
if (useDaemon) {
synchronized {
if (idleWorkers.size > 0) {
return idleWorkers.dequeue()
}
}
createThroughDaemon()
} else {
createSimpleWorker()
Expand Down Expand Up @@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}

/**
* Monitor all the idle workers, kill them after timeout.
*/
private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {

setDaemon(true)

override def run() {
while (true) {
synchronized {
if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
cleanupIdleWorkers()
lastActivity = System.currentTimeMillis()
}
}
Thread.sleep(10000)
}
}
}

private def cleanupIdleWorkers() {
while (idleWorkers.length > 0) {
val worker = idleWorkers.dequeue()
try {
// the worker will exit after closing the socket
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}

private def stopDaemon() {
synchronized {
if (useDaemon) {
cleanupIdleWorkers()

// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
Expand All @@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}

def stopWorker(worker: Socket) {
if (useDaemon) {
if (daemon != null) {
daemonWorkers.get(worker).foreach { pid =>
// tell daemon to kill worker by pid
val output = new DataOutputStream(daemon.getOutputStream)
output.writeInt(pid)
output.flush()
daemon.getOutputStream.flush()
synchronized {
if (useDaemon) {
if (daemon != null) {
daemonWorkers.get(worker).foreach { pid =>
// tell daemon to kill worker by pid
val output = new DataOutputStream(daemon.getOutputStream)
output.writeInt(pid)
output.flush()
daemon.getOutputStream.flush()
}
}
} else {
simpleWorkers.get(worker).foreach(_.destroy())
}
} else {
simpleWorkers.get(worker).foreach(_.destroy())
}
worker.close()
}

def releaseWorker(worker: Socket) {
if (useDaemon) {
synchronized {
lastActivity = System.currentTimeMillis()
idleWorkers.enqueue(worker)
}
} else {
// Cleanup the worker socket. This will also cause the Python worker to exit.
try {
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}
}

private object PythonWorkerFactory {
val PROCESS_WAIT_TIMEOUT_MS = 10000
val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

/**
* Exception thrown when there is an exception in
* executing the callback in TaskCompletionListener.
*/
private[spark]
class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception {

override def getMessage: String = {
if (errorMessages.size == 1) {
errorMessages.head
} else {
errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
}
}
}
Loading

0 comments on commit 09d02c3

Please sign in to comment.