[SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName that is not safe in scala #21276

Closed
wants to merge 7 commits into from
Changes from 2 commits
60 changes: 60 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2715,6 +2715,66 @@ private[spark] object Utils extends Logging {
HashCodes.fromBytes(secretBytes).toString()
}

  /**
   * A safer alternative to a Scala object's getClass.getSimpleName and to
   * Utils.getFormattedClassName, both of which may throw a Malformed class name error.
   * This method mimics scalatest's getSimpleNameOfAnObjectsClass.
   */
  def getSimpleName(fullyQualifiedName: String): String = {
Member

How about changing the method signature to def getSimpleName(obj: AnyRef)?
Then we would simply handle two cases, a well-formed name and a malformed one, e.g.,

def getSimpleName(obj: AnyRef): String = {
  if (incorrect case) {
    // Canonicalizes the name for correction
    ...
  } else {
    // If there is no problem, just return getSimpleName
    obj.getClass.getSimpleName
  }
}

Author

@fangshil May 11, 2018

@maropu an if statement that decides whether the name is well-formed may not be stable or comprehensive. How about we use try-catch here and fall back to the alternative only if obj.getClass.getSimpleName throws the Malformed class name error?

Author

Updated. Since getSimpleName is a method of Class, it makes sense for Utils.getSimpleName to also take a Class param. Utils.getSimpleName now tries the Class's getSimpleName first and uses the alternative only if the internal error is caught. This should minimize the potential impact of this patch.
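
For readers following the thread, the try-first, fall-back-on-error shape described above might look roughly like the sketch below. It is not the code from this patch: the object name SimpleNameSketch and the helpers stripPackages and stripTrailingDollars are hypothetical, and the fallback is deliberately simpler than the stripDollars logic in the diff.

```scala
object SimpleNameSketch {
  // Hypothetical helper: drop everything up to and including the last '.'.
  // lastIndexOf returns -1 when there is no dot, so the whole name is kept.
  private def stripPackages(name: String): String =
    name.substring(name.lastIndexOf('.') + 1)

  // Hypothetical helper: drop trailing '$' characters (the Scala object suffix).
  private def stripTrailingDollars(name: String): String =
    name.reverse.dropWhile(_ == '$').reverse

  // Try the JDK method first; parse the binary name only if it throws.
  def getSimpleName(cls: Class[_]): String =
    try {
      cls.getSimpleName
    } catch {
      // Some JVMs throw InternalError("Malformed class name") for nested Scala classes.
      case _: InternalError => stripTrailingDollars(stripPackages(cls.getName))
    }
}
```

Because the fallback only runs when the JDK call fails, classes that already work keep their existing simple names, which is the "minimize the impact" point made above.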

    stripDollars(parseSimpleName(fullyQualifiedName))
  }

  /**
   * Remove the packages from a fully qualified class name
   */
  private def parseSimpleName(fullyQualifiedName: String): String = {
Member

fullyQualifiedName.split("\\.").takeRight(1)?

Author

This method was copied from scalatest. I think it makes sense to update it in this cleaner way.
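
As a rough comparison of the two approaches discussed in this thread, the sketch below puts an index-based variant next to a split-based one. The object and method names are hypothetical. Note that takeRight(1) on the result of split yields an Array[String] rather than a String, so the split variant here uses lastOption instead; it also guards against split returning an empty array, which happens for an input such as ".".

```scala
object ParseSimpleNameSketch {
  // Index-based variant, in the spirit of the code under review.
  def byLastIndexOf(name: String): String = {
    val dotPos = name.lastIndexOf('.')
    if (dotPos != -1) name.substring(dotPos + 1) else name
  }

  // Split-based variant, along the lines of the suggestion.
  // lastOption returns None for an empty array, so we fall back to the input.
  def bySplit(name: String): String =
    name.split("\\.").lastOption.getOrElse(name)
}
```

The two variants agree on ordinary class names but differ on a trailing dot: split discards trailing empty strings, so "a.b." gives "b" with bySplit and an empty string with byLastIndexOf.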

    // Find last dot position
    val dotPos = fullyQualifiedName.lastIndexOf('.')
    // Need to check the dotPos != fullyQualifiedName.length
    if (dotPos != -1 && dotPos != fullyQualifiedName.length) {
      fullyQualifiedName.substring(dotPos + 1)
    } else {
      fullyQualifiedName
    }
  }

  /**
   * Remove trailing dollar signs from qualified class name,
   * and return the trailing part after the last dollar sign in the middle
   */
  private def stripDollars(s: String): String = {
    val lastDollarIndex = s.lastIndexOf('$')
    if (lastDollarIndex < s.length - 1) {
      // The last char is not a dollar sign
      if (lastDollarIndex == -1 || !s.contains("$iw")) {
        // The name does not have a dollar sign or is not an interpreter
        // generated class, so we should return the full string
        s
      } else {
        // The class name is interpreter generated,
        // return the part after the last dollar sign
        // This is the same behavior as getClass.getSimpleName
        s.substring(lastDollarIndex + 1)
      }
    }
    else {
Member

style:

if (...) {
} else {
}

      // The last char is a dollar sign
      // Find last non-dollar char
      val lastNonDollarChar = s.reverse.find(_ != '$')
      lastNonDollarChar match {
        case None => s
        case Some(c) =>
          val lastNonDollarIndex = s.lastIndexOf(c)
          if (lastNonDollarIndex == -1) {
            s
          } else {
            // Strip the trailing dollar signs
            // Invoke stripDollars again to get the simple name
            stripDollars(s.substring(0, lastNonDollarIndex + 1))
          }
      }
    }
  }
}

private[util] object CallerContext extends Logging {
26 changes: 26 additions & 0 deletions core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -1168,6 +1168,32 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
}
}

object A {
Member

nit: more meaningful name, e.g.:

object MalformedClassObject {
  class MalformedClass
}

Author

updated

    class B
  }

  test("Safe getSimpleName") {
    val fullname1 = "org.apache.spark.TestClass$MyClass"
    assert(Utils.getSimpleName(fullname1) === "TestClass$MyClass")

    val fullname2 = "org.apache.spark.TestClass$MyClass$"
    assert(Utils.getSimpleName(fullname2) === "TestClass$MyClass")

    val fullname3 = "org.apache.spark.TestClass$MyClass$1$"
    assert(Utils.getSimpleName(fullname3) === "TestClass$MyClass$1")

    val fullname4 = "TestClass$MyClass$1$"
    assert(Utils.getSimpleName(fullname4) === "TestClass$MyClass$1")

    val fullname5 = "$iwC$iwC$$iwC$$iwC$TestClass$MyClass$"
    assert(Utils.getSimpleName(fullname5) === "MyClass")
Member

Is it possible to have a case which causes Malformed class name error?

Author

Updated the test to add a case showing that Class's getSimpleName throws the Malformed class name error while Utils.getSimpleName works.


    intercept[java.lang.InternalError] {
      classOf[A.B].getSimpleName
    }
Member

Can we check the error message?

Author

updated
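
One way to check the message without a test framework is sketched below. It is only an illustration: the names MalformedNameCheckSketch, Outer, Inner and simpleNameOrError are hypothetical, and whether getSimpleName throws at all depends on the JVM in use, so the sketch handles both outcomes. A plain try/catch is used because InternalError is a VirtualMachineError, which scala.util.Try does not catch.

```scala
object MalformedNameCheckSketch {
  // Hypothetical nested pair, similar in shape to the one used in the test.
  object Outer { class Inner }

  // Left(error message) if getSimpleName throws, Right(simple name) otherwise.
  def simpleNameOrError(cls: Class[_]): Either[String, String] =
    try {
      Right(cls.getSimpleName)
    } catch {
      case e: InternalError => Left(String.valueOf(e.getMessage))
    }

  def main(args: Array[String]): Unit =
    simpleNameOrError(classOf[Outer.Inner]) match {
      // JVMs that throw report the message asserted here.
      case Left(msg) => assert(msg.contains("Malformed class name"))
      // JVMs that do not throw simply return a name; print it for inspection.
      case Right(name) => println(name)
    }
}
```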

    assert(Utils.getSimpleName(classOf[A.B].getName) === "UtilsSuite$A$B")
  }
}

private class SimpleExtension
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

object TypedAggregateExpression {
def apply[BUF : Encoder, OUT : Encoder](
@@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction {
s"$nodeName($input)"
}

-  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+  // aggregator.getClass.getSimpleName can cause Malformed class name error,
+  // call safer `Utils.getSimpleName` instead
+  override def nodeName: String = Utils.getSimpleName(aggregator.getClass.getName);
}

// TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.