apache#10 performance issue when gc blocks [follow up]
hn5092 committed Dec 27, 2018
1 parent 4af8848 commit f3f6174
Showing 11 changed files with 88 additions and 36 deletions.
64 changes: 39 additions & 25 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -19,7 +19,7 @@ package org.apache.spark

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.Collections
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, ExecutorService, ScheduledExecutorService, TimeUnit}

import scala.collection.JavaConverters._

@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils}


/**
* Classes that represent cleaning tasks.
*/
@@ -112,6 +113,15 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private val blockOnShuffleCleanupTasks = sc.conf.getBoolean(
"spark.cleaner.referenceTracking.blocking.shuffle", false)

+  /**
+   * Number of threads used by the cleanup thread pool.
+   */
+  private val cleanupTaskThreads = sc.conf.getInt(
+    "spark.cleaner.referenceTracking.cleanupThreadNumber", 100)
+
+  private val cleanupExecutorPool: ExecutorService =
+    ThreadUtils.newDaemonFixedThreadPool(cleanupTaskThreads, "cleanup")
+
@volatile private var stopped = false

/** Attach a listener object to get information of when objects are cleaned. */
@@ -177,33 +187,37 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
  /** Keep cleaning RDD, shuffle, and broadcast state. */
  private def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) {
    while (!stopped) {
-      try {
-        val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
-          .map(_.asInstanceOf[CleanupTaskWeakReference])
-        // Synchronize here to avoid being interrupted on stop()
-        synchronized {
-          reference.foreach { ref =>
-            logDebug("Got cleaning task " + ref.task)
-            referenceBuffer.remove(ref)
-            ref.task match {
-              case CleanRDD(rddId) =>
-                doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
-              case CleanShuffle(shuffleId) =>
-                doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
-              case CleanBroadcast(broadcastId) =>
-                doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
-              case CleanAccum(accId) =>
-                doCleanupAccum(accId, blocking = blockOnCleanupTasks)
-              case CleanCheckpoint(rddId) =>
-                doCleanCheckpoint(rddId)
-            }
-          }
-        }
-      } catch {
-        case ie: InterruptedException if stopped => // ignore
-        case e: Exception => logError("Error in cleaning thread", e)
-      }
+      Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
+        .map(_.asInstanceOf[CleanupTaskWeakReference]).foreach { r =>
+          referenceBuffer.remove(r)
+          runCleanTask(r)
+        }
    }
  }

+  private def runCleanTask(ref: CleanupTaskWeakReference) = {
+    cleanupExecutorPool.submit(new Runnable {
+      override def run(): Unit = {
+        try {
+          ref.task match {
+            case CleanRDD(rddId) =>
+              doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
+            case CleanShuffle(shuffleId) =>
+              doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
+            case CleanBroadcast(broadcastId) =>
+              doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
+            case CleanAccum(accId) =>
+              doCleanupAccum(accId, blocking = blockOnCleanupTasks)
+            case CleanCheckpoint(rddId) =>
+              doCleanCheckpoint(rddId)
+          }
+        } catch {
+          case ie: InterruptedException if stopped => // ignore
+          case e: Exception => logError("Error in cleaning thread", e)
+        }
+      }
+    })
+  }

/** Perform RDD cleanup. */
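The restructuring above separates draining the reference queue from executing cleanup: keepCleaning only dequeues weak references, while the potentially blocking doCleanup* calls run on a fixed pool of spark.cleaner.referenceTracking.cleanupThreadNumber (default 100) daemon threads, so a GC burst that floods the queue no longer stalls behind a single slow unbroadcast. A minimal standalone sketch of that pattern, using plain java.util.concurrent in place of Spark's ThreadUtils; CleanupPoolSketch and its members are illustrative, not part of the commit:

import java.lang.ref.{Reference, ReferenceQueue}
import java.util.concurrent.{Executors, ThreadFactory}

object CleanupPoolSketch {
  // Daemon threads, standing in for ThreadUtils.newDaemonFixedThreadPool.
  private val daemonFactory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "cleanup")
      t.setDaemon(true)
      t
    }
  }
  private val pool = Executors.newFixedThreadPool(100, daemonFactory)
  private val referenceQueue = new ReferenceQueue[AnyRef]
  @volatile private var stopped = false

  // The polling thread never runs a cleanup itself: it dequeues a reference
  // and hands it to the pool, staying responsive however slow cleanups are.
  def keepCleaning(clean: Reference[_ <: AnyRef] => Unit): Unit = {
    while (!stopped) {
      Option(referenceQueue.remove(100L)).foreach { ref =>
        pool.submit(new Runnable {
          override def run(): Unit = clean(ref)
        })
      }
    }
  }
}

Two behavioral consequences are visible in the diff: cleanups for different resources may now complete out of arrival order, and the old synchronized block that shielded in-flight cleanups from stop() is gone, so pool tasks can race with shutdown.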
core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -807,7 +807,7 @@ private[spark] object MapOutputTracker extends Logging {
if (arr.length >= minBroadcastSize) {
// Use broadcast instead.
// Important arr(0) is the tag == DIRECT, ignore that while deserializing !
-      val bcast = broadcastManager.newBroadcast(arr, isLocal)
+      val bcast = broadcastManager.newBroadcast(arr, isLocal, null)
// toByteArray creates copy, so we can reuse out
out.reset()
out.write(BROADCAST)
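MapOutputTracker's map-status broadcasts belong to no SQL execution, so this call site passes null and those broadcasts stay on the ContextCleaner path. An alternative the commit does not take: defaulting the new parameter would have left two-argument callers like this one untouched. A signature-only sketch, with BroadcastSketch standing in for the real factory types:

import scala.reflect.ClassTag

trait BroadcastSketch {
  // executionId = null means "no owning query"; defaulting it preserves
  // source compatibility for existing two-argument callers.
  def newBroadcast[T: ClassTag](value: T, isLocal: Boolean,
      executionId: String = null): AnyRef
}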
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1486,10 +1486,13 @@ class SparkContext(config: SparkConf) extends Logging {
assertNotStopped()
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
-    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+    val executionId = getLocalProperty("spark.sql.execution.id")
+    val bc = env.broadcastManager.newBroadcast[T](value, isLocal, executionId)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
-    cleaner.foreach(_.registerBroadcastForCleanup(bc))
+    if (executionId == null) {
+      cleaner.foreach(_.registerBroadcastForCleanup(bc))
+    }
bc
}

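This hunk sets the ownership rule: when "spark.sql.execution.id" is present in the caller's local properties, the new broadcast belongs to that SQL execution and is never registered with ContextCleaner; otherwise the weak-reference path is unchanged. A compact sketch of the rule (the helper is illustrative, not in the commit):

// True when ContextCleaner should track the broadcast via weak references;
// false when BroadcastManager.cleanBroadCast will reclaim it at query end.
def gcManaged(executionId: String): Boolean = executionId == null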
core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -19,13 +19,16 @@ package org.apache.spark.broadcast

import java.util.concurrent.atomic.AtomicLong

+import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag

+import avro.shaded.com.google.common.collect.Maps
import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging


private[spark] class BroadcastManager(
val isDriver: Boolean,
conf: SparkConf,
@@ -34,6 +37,7 @@ private[spark] class BroadcastManager(

private var initialized = false
private var broadcastFactory: BroadcastFactory = null
+  var cachedBroadcast = Maps.newConcurrentMap[String, ListBuffer[Long]]()

initialize()

@@ -54,12 +58,31 @@

private val nextBroadcastId = new AtomicLong(0)

+  private[spark] def currentBroadcastId: Long = nextBroadcastId.get()

private[broadcast] val cachedValues = {
new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
}

-  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+  def cleanBroadCast(executionId: String): Unit = {
+    if (cachedBroadcast.containsKey(executionId)) {
+      cachedBroadcast.get(executionId).foreach(broadcastId => unbroadcast(broadcastId, true, false))
+      cachedBroadcast.remove(executionId)
+    }
+  }
+
+  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, executionId: String): Broadcast[T] = {
+    val broadcastId = nextBroadcastId.getAndIncrement()
+    if (executionId != null) {
+      if (cachedBroadcast.containsKey(executionId)) {
+        cachedBroadcast.get(executionId) += broadcastId
+      } else {
+        val list = new ListBuffer[Long]
+        list += broadcastId
+        cachedBroadcast.put(executionId, list)
+      }
+    }
+    broadcastFactory.newBroadcast[T](value_, isLocal, broadcastId)
}

def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
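cachedBroadcast is written from SparkContext.broadcast and drained from SQLExecution, so the containsKey/put sequence above can lose a broadcast id if one execution broadcasts from two threads at once (e.g. parallel job submission). A standalone sketch of the same bookkeeping with that window closed; the object is illustrative, though the names mirror the diff:

import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}

import scala.collection.mutable.ListBuffer

object BroadcastBookkeepingSketch {
  private val cachedBroadcast = new ConcurrentHashMap[String, ListBuffer[Long]]()

  // computeIfAbsent makes create-or-get atomic, unlike containsKey + put.
  def register(executionId: String, broadcastId: Long): Unit = {
    val buf = cachedBroadcast.computeIfAbsent(executionId,
      new JFunction[String, ListBuffer[Long]] {
        override def apply(k: String): ListBuffer[Long] = ListBuffer.empty[Long]
      })
    buf.synchronized { buf += broadcastId } // ListBuffer is not thread-safe
  }

  // remove() first, so a late register cannot resurrect a drained entry.
  def cleanBroadCast(executionId: String)(unbroadcast: Long => Unit): Unit =
    Option(cachedBroadcast.remove(executionId)).foreach(_.foreach(unbroadcast))
}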
core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -54,7 +54,7 @@ import org.apache.spark.rdd.RDD
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
-    taskBinary: Broadcast[Array[Byte]],
+    val taskBinary: Broadcast[Array[Byte]],
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -55,7 +55,7 @@ import org.apache.spark.shuffle.ShuffleWriter
private[spark] class ShuffleMapTask(
stageId: Int,
stageAttemptId: Int,
-    taskBinary: Broadcast[Array[Byte]],
+    val taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
localProperties: Properties,
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -530,6 +530,13 @@ private[spark] class TaskSetManager(
private def maybeFinishTaskSet() {
if (isZombie && runningTasks == 0) {
sched.taskSetFinished(this)
+      val broadcastId = taskSet.tasks.head match {
+        case resultTask: ResultTask[Any, Any] =>
+          resultTask.taskBinary.id
+        case shuffleMapTask: ShuffleMapTask =>
+          shuffleMapTask.taskBinary.id
+      }
+      SparkEnv.get.broadcastManager.unbroadcast(broadcastId, true, false)
if (tasksSuccessful == numTasks) {
blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet(
taskSet.stageId,
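Every task in a set is built from the same serialized binary, so unbroadcasting the head task's taskBinary once per finished task set releases the stage's broadcast eagerly instead of waiting for GC. The match covers only the two Task subclasses Spark schedules today and would throw MatchError on any other subtype; a sketch of the same dispatch made total (helper name is illustrative, and it assumes scheduler-package visibility of ResultTask and ShuffleMapTask):

// Returns None for unknown Task subtypes instead of throwing MatchError.
private def taskBinaryId(task: Task[_]): Option[Long] = task match {
  case r: ResultTask[_, _] => Some(r.taskBinary.id)
  case s: ShuffleMapTask   => Some(s.taskBinary.id)
  case _                   => None // fall back to GC-driven cleanup
}

// Usage at the call site above:
// taskBinaryId(taskSet.tasks.head).foreach(
//   SparkEnv.get.broadcastManager.unbroadcast(_, true, false))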
core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -33,7 +33,6 @@ import scala.util.Random
import scala.util.control.NonFatal

import com.codahale.metrics.{MetricRegistry, MetricSet}
-import com.google.common.io.CountingOutputStream

import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -202,7 +202,12 @@ class BlockManagerMasterEndpoint(
0 // zero blocks were removed
}
}.toSeq

+    val blocksToRemove = blockLocations.keySet().asScala
+      .collect {
+        case blockId @ BroadcastBlockId(`broadcastId`, _) =>
+          blockId
+      }
+    blocksToRemove.foreach(blockLocations.remove)
Future.sequence(futures)
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

-import org.apache.spark.SparkContext
+import org.apache.spark.SparkEnv
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}

@@ -84,6 +84,7 @@
} finally {
executionIdToQueryExecution.remove(executionId)
sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+      SparkEnv.get.broadcastManager.cleanBroadCast(executionId.toString)
}
}

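Because the call sits in the finally block of withNewExecutionId, per-execution broadcasts are released whether the query completes or throws. A simplified shape of the method after this change, with the listener events and execution-id bookkeeping elided and cleanup standing in for SparkEnv.get.broadcastManager.cleanBroadCast:

def withNewExecutionId[T](executionId: Long, cleanup: String => Unit)(body: => T): T = {
  try {
    body
  } finally {
    // Runs on success and failure alike: broadcasts die with the query,
    // not with the driver's garbage collector.
    cleanup(executionId.toString)
  }
}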
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -289,7 +289,7 @@ class ParquetFileFormat
sparkSession: SparkSession,
options: Map[String, String],
path: Path): Boolean = {
-    false
+    true
}

override def buildReaderWithPartitionValues(
