
Commit

Made changes based on PR comments.
tdas committed Sep 23, 2014
1 parent 390b45d commit dc54c71
Showing 3 changed files with 16 additions and 17 deletions.
core/src/main/scala/org/apache/spark/util/Utils.scala (7 additions, 8 deletions)
@@ -865,26 +865,25 @@ private[spark] object Utils extends Logging {
   }
 
   /** Default filtering function for finding call sites using `getCallSite`. */
-  private def defaultCallSiteFilterFunc(className: String): Boolean = {
+  private def coreExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the "core" Spark API that we want to skip when
     // finding the call site of a method.
     val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
     val SCALA_CLASS_REGEX = """^scala""".r
-    val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
+    val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
     val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
-    // If the class neither belongs to Spark nor is a simple Scala class, then it is a
-    // user-defined class
-    !isSparkClass && !isScalaClass
+    // If the class is a Spark internal class or a Scala class, then exclude.
+    isSparkCoreClass || isScalaClass
   }
 
   /**
   * When called inside a class in the spark package, returns the name of the user code class
   * (outside the spark package) that called into Spark, as well as which Spark method they called.
   * This is used, for example, to tell users where in their code each RDD got created.
   *
-  * @param classFilterFunc Function that returns true if the given class belongs to user code
+  * @param skipClass Function that is used to exclude non-user-code classes.
   */
-  def getCallSite(classFilterFunc: String => Boolean = defaultCallSiteFilterFunc): CallSite = {
+  def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = {
     val trace = Thread.currentThread.getStackTrace()
       .filterNot { ste:StackTraceElement =>
         // When running under some profilers, the current stack trace might contain some bogus
@@ -905,7 +904,7 @@ private[spark] object Utils extends Logging {
 
     for (el <- trace) {
       if (insideSpark) {
-        if (!classFilterFunc(el.getClassName)) {
+        if (skipClass(el.getClassName)) {
          lastSparkMethod = if (el.getMethodName == "<init>") {
            // Spark method is a constructor; get its class name
            el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
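Note on the semantics of the rename above: the predicate passed to getCallSite now returns true for classes to skip, rather than true for user code. A minimal, runnable sketch of the new convention (the object name and sample class names are illustrative and not part of this patch):

// Sketch only: mirrors the behavior of the new coreExclusionFunction in Utils.scala.
object CoreExclusionSketch {
  val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
  val SCALA_CLASS_REGEX = """^scala""".r

  // Returns true when the class should be skipped while walking the stack trace.
  def coreExclusionFunction(className: String): Boolean = {
    val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
    val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
    isSparkCoreClass || isScalaClass
  }

  def main(args: Array[String]): Unit = {
    println(coreExclusionFunction("org.apache.spark.rdd.RDD"))             // true: skipped
    println(coreExclusionFunction("scala.collection.immutable.List"))      // true: skipped
    println(coreExclusionFunction("com.example.MyApp"))                    // false: reported as the call site
    // Streaming classes do not match the core regex, which is why DStream
    // supplies its own, broader exclusion function (see DStream.scala below).
    println(coreExclusionFunction("org.apache.spark.streaming.dstream.DStream")) // false
  }
}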
streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala (1 addition, 1 deletion)
@@ -447,7 +447,7 @@ class StreamingContext private[streaming] (
       throw new SparkException("StreamingContext has already been stopped")
     }
     validate()
-    sparkContext.setCallSite(DStream.getCallSite())
+    sparkContext.setCallSite(DStream.getCreationSite())
     scheduler.start()
     state = Started
   }
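For context, a hedged usage sketch of what the start() change buys (the application class, host, and port are made up; the streaming API calls are standard): start() now tags the SparkContext with the user-code call site computed by DStream.getCreationSite(), so jobs generated by the scheduler point at a line in the application rather than at an internal streaming class.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object CallSiteDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("CallSiteDemo")
    val ssc = new StreamingContext(conf, Seconds(1))

    // Each DStream also records its own creationSite (see DStream.scala below).
    val lines = ssc.socketTextStream("localhost", 9999)
    lines.count().print()

    ssc.start()            // now resolves the call site to this line of user code
    ssc.awaitTermination()
  }
}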
streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala (8 additions, 8 deletions)
@@ -108,7 +108,7 @@ abstract class DStream[T: ClassTag] (
   def context = ssc
 
   /* Set the creation call site */
-  private[streaming] val creationSite = DStream.getCallSite()
+  private[streaming] val creationSite = DStream.getCreationSite()
 
   /** Persist the RDDs of this DStream with the given storage level */
   def persist(level: StorageLevel): DStream[T] = {
@@ -805,25 +805,25 @@ abstract class DStream[T: ClassTag] (
 private[streaming] object DStream {
 
   /** Get the creation site of a DStream from the stack trace of when the DStream is created. */
-  def getCallSite(): CallSite = {
+  def getCreationSite(): CallSite = {
     val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r
     val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r
     val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r
     val SCALA_CLASS_REGEX = """^scala""".r
 
-    /** Filtering function that returns true for classes that belong to a streaming application */
-    def streamingClassFilterFunc(className: String): Boolean = {
+    /** Filtering function that excludes non-user classes for a streaming application */
+    def streamingExclustionFunction(className: String): Boolean = {
       def doesMatch(r: Regex) = r.findFirstIn(className).isDefined
       val isSparkClass = doesMatch(SPARK_CLASS_REGEX)
       val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX)
       val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX)
       val isScalaClass = doesMatch(SCALA_CLASS_REGEX)
 
       // If the class is a spark example class or a streaming test class then it is considered
-      // as a streaming application class. Otherwise, consider any non-Spark and non-Scala class
-      // as streaming application class.
-      isSparkExampleClass || isSparkStreamingTestClass || !(isSparkClass || isScalaClass)
+      // as a streaming application class and don't exclude. Otherwise, exclude any
+      // Spark or Scala class, as the rest would be streaming application classes.
+      (isSparkClass || isScalaClass) && !isSparkExampleClass && !isSparkStreamingTestClass
     }
-    org.apache.spark.util.Utils.getCallSite(streamingClassFilterFunc)
+    org.apache.spark.util.Utils.getCallSite(streamingExclustionFunction)
  }
 }
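A runnable sketch of the streaming-side predicate (spelled streamingExclustionFunction in the patch; the object and sample class names below are illustrative): Spark and Scala classes are excluded, except Spark example and streaming test classes, which are deliberately kept so that call sites inside them are still reported.

import scala.util.matching.Regex

object StreamingExclusionSketch {
  val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r
  val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r
  val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r
  val SCALA_CLASS_REGEX = """^scala""".r

  // Returns true when the class should be skipped while locating a DStream's creation site.
  def streamingExclusionFunction(className: String): Boolean = {
    def doesMatch(r: Regex) = r.findFirstIn(className).isDefined
    val isSparkClass = doesMatch(SPARK_CLASS_REGEX)
    val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX)
    val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX)
    val isScalaClass = doesMatch(SCALA_CLASS_REGEX)
    (isSparkClass || isScalaClass) && !isSparkExampleClass && !isSparkStreamingTestClass
  }

  def main(args: Array[String]): Unit = {
    println(streamingExclusionFunction("org.apache.spark.streaming.dstream.DStream"))           // true: skipped
    println(streamingExclusionFunction("scala.collection.Iterator"))                            // true: skipped
    println(streamingExclusionFunction("org.apache.spark.examples.streaming.NetworkWordCount")) // false: kept
    println(streamingExclusionFunction("org.apache.spark.streaming.test.SomeUserTest"))         // false: kept (hypothetical test class)
    println(streamingExclusionFunction("com.example.MyStreamingApp"))                           // false: kept
  }
}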
