Skip to content

Commit

Permalink
[SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName …
Browse files Browse the repository at this point in the history
…that is not safe in scala

## What changes were proposed in this pull request?

When user create a aggregator object in scala and pass the aggregator to Spark Dataset's agg() method, Spark's will initialize TypedAggregateExpression with the nodeName field as aggregator.getClass.getSimpleName. However, getSimpleName is not safe in scala environment, depending on how user creates the aggregator object. For example, if the aggregator class full qualified name is "com.my.company.MyUtils$myAgg$2$", the getSimpleName will throw java.lang.InternalError "Malformed class name". This has been reported in scalatest scalatest/scalatest#1044 and discussed in many scala upstream jiras such as SI-8110, SI-5425.

To fix this issue, we follow the solution in scalatest/scalatest#1044 to add safer version of getSimpleName as a util method, and TypedAggregateExpression will invoke this util method rather than getClass.getSimpleName.

## How was this patch tested?
added unit test

Author: Fangshi Li <fli@linkedin.com>

Closes #21276 from fangshil/SPARK-24216.
  • Loading branch information
Fangshi Li authored and cloud-fan committed Jun 12, 2018
1 parent f0ef1b3 commit cc88d7f
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 6 deletions.
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
}

override def toString: String = {
// getClass.getSimpleName can cause Malformed class name error,
// call safer `Utils.getSimpleName` instead
if (metadata == null) {
"Un-registered Accumulator: " + getClass.getSimpleName
"Un-registered Accumulator: " + Utils.getSimpleName(getClass)
} else {
getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
}
}
}
Expand Down
59 changes: 58 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.util

import java.io._
import java.lang.{Byte => JByte}
import java.lang.InternalError
import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
Expand Down Expand Up @@ -1820,7 +1821,7 @@ private[spark] object Utils extends Logging {

/** Return the class name of the given object, removing all dollar signs */
def getFormattedClassName(obj: AnyRef): String = {
obj.getClass.getSimpleName.replace("$", "")
getSimpleName(obj.getClass).replace("$", "")
}

/**
Expand Down Expand Up @@ -2715,6 +2716,62 @@ private[spark] object Utils extends Logging {
HashCodes.fromBytes(secretBytes).toString()
}

/**
* Safer than Class obj's getSimpleName which may throw Malformed class name error in scala.
* This method mimicks scalatest's getSimpleNameOfAnObjectsClass.
*/
def getSimpleName(cls: Class[_]): String = {
try {
return cls.getSimpleName
} catch {
case err: InternalError => return stripDollars(stripPackages(cls.getName))
}
}

/**
* Remove the packages from full qualified class name
*/
private def stripPackages(fullyQualifiedName: String): String = {
fullyQualifiedName.split("\\.").takeRight(1)(0)
}

/**
* Remove trailing dollar signs from qualified class name,
* and return the trailing part after the last dollar sign in the middle
*/
private def stripDollars(s: String): String = {
val lastDollarIndex = s.lastIndexOf('$')
if (lastDollarIndex < s.length - 1) {
// The last char is not a dollar sign
if (lastDollarIndex == -1 || !s.contains("$iw")) {
// The name does not have dollar sign or is not an intepreter
// generated class, so we should return the full string
s
} else {
// The class name is intepreter generated,
// return the part after the last dollar sign
// This is the same behavior as getClass.getSimpleName
s.substring(lastDollarIndex + 1)
}
}
else {
// The last char is a dollar sign
// Find last non-dollar char
val lastNonDollarChar = s.reverse.find(_ != '$')
lastNonDollarChar match {
case None => s
case Some(c) =>
val lastNonDollarIndex = s.lastIndexOf(c)
if (lastNonDollarIndex == -1) {
s
} else {
// Strip the trailing dollar signs
// Invoke stripDollars again to get the simple name
stripDollars(s.substring(0, lastNonDollarIndex + 1))
}
}
}
}
}

private[util] object CallerContext extends Logging {
Expand Down
16 changes: 16 additions & 0 deletions core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
}
}

object MalformedClassObject {
class MalformedClass
}

test("Safe getSimpleName") {
// getSimpleName on class of MalformedClass will result in error: Malformed class name
// Utils.getSimpleName works
val err = intercept[java.lang.InternalError] {
classOf[MalformedClassObject.MalformedClass].getSimpleName
}
assert(err.getMessage === "Malformed class name")

assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) ===
"UtilsSuite$MalformedClassObject$MalformedClass")
}
}

private class SimpleExtension
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.util.Utils

/**
* A small wrapper that defines a training session for an estimator, and some methods to log
Expand All @@ -47,7 +48,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (

private val id = UUID.randomUUID()
private val prefix = {
val className = estimator.getClass.getSimpleName
// estimator.getClass.getSimpleName can cause Malformed class name error,
// call safer `Utils.getSimpleName` instead
val className = Utils.getSimpleName(estimator.getClass)
s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

object TypedAggregateExpression {
def apply[BUF : Encoder, OUT : Encoder](
Expand Down Expand Up @@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction {
s"$nodeName($input)"
}

override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
// aggregator.getClass.getSimpleName can cause Malformed class name error,
// call safer `Utils.getSimpleName` instead
override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$");
}

// TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ trait DataSourceV2StringFormat {

private def sourceName: String = source match {
case registered: DataSourceRegister => registered.shortName()
case _ => source.getClass.getSimpleName.stripSuffix("$")
// source.getClass.getSimpleName can cause Malformed class name error,
// call safer `Utils.getSimpleName` instead
case _ => Utils.getSimpleName(source.getClass)
}

def metadataString: String = {
Expand Down

0 comments on commit cc88d7f

Please sign in to comment.