From cc88d7fad16e8b5cbf7b6b9bfe412908782b4a45 Mon Sep 17 00:00:00 2001
From: Fangshi Li
Date: Tue, 12 Jun 2018 12:10:08 -0700
Subject: [PATCH] [SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName that is not safe in scala

## What changes were proposed in this pull request?

When a user creates an aggregator object in Scala and passes it to Spark Dataset's agg() method, Spark initializes TypedAggregateExpression with the nodeName field set to aggregator.getClass.getSimpleName. However, getSimpleName is not safe in a Scala environment: depending on how the user defines the aggregator, it can fail. For example, if the aggregator class's fully qualified name is "com.my.company.MyUtils$myAgg$2$", getSimpleName throws java.lang.InternalError: "Malformed class name". This has been reported in scalatest (https://github.com/scalatest/scalatest/pull/1044) and discussed in several Scala upstream tickets such as SI-8110 and SI-5425.

To fix this, we follow the solution in https://github.com/scalatest/scalatest/pull/1044: add a safer version of getSimpleName as a util method, and have TypedAggregateExpression invoke that util method instead of getClass.getSimpleName.
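For illustration, here is a minimal standalone sketch of the failure mode and of the new util method. This snippet is not part of the patch: the `Demo` object and the `org.apache.spark.example` package are made up, the class nesting mirrors the unit test added below, and whether `getSimpleName` actually throws depends on the JVM (it does on Java 8):

```scala
package org.apache.spark.example // under org.apache.spark so the private[spark] Utils is visible

import org.apache.spark.util.Utils

object Demo {
  // Mirrors the class layout used by the new test in UtilsSuite: a class nested inside
  // an object compiles to a name such as Demo$MalformedClassObject$MalformedClass.
  object MalformedClassObject {
    class MalformedClass
  }

  def main(args: Array[String]): Unit = {
    val cls = classOf[MalformedClassObject.MalformedClass]

    // On some JVMs (notably Java 8), getSimpleName throws
    // java.lang.InternalError: Malformed class name for names like the one above.
    val unsafe =
      try cls.getSimpleName
      catch { case e: InternalError => s"<failed: ${e.getMessage}>" }

    // The helper added by this patch never throws; when getSimpleName fails
    // it falls back to parsing cls.getName.
    val safe = Utils.getSimpleName(cls)

    println(s"getSimpleName:       $unsafe")
    println(s"Utils.getSimpleName: $safe")
  }
}
```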
## How was this patch tested?

Added a unit test.

Author: Fangshi Li

Closes #21276 from fangshil/SPARK-24216.
---
 .../org/apache/spark/util/AccumulatorV2.scala |  6 +-
 .../scala/org/apache/spark/util/Utils.scala   | 59 ++++++++++++++++++-
 .../org/apache/spark/util/UtilsSuite.scala    | 16 +++++
 .../spark/ml/util/Instrumentation.scala       |  5 +-
 .../aggregate/TypedAggregateExpression.scala  |  5 +-
 .../v2/DataSourceV2StringFormat.scala         |  4 +-
 6 files changed, 89 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 3b469a69437b9..bf618b4afbce0 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
   }
 
   override def toString: String = {
+    // getClass.getSimpleName can cause a Malformed class name error,
+    // call the safer `Utils.getSimpleName` instead
     if (metadata == null) {
-      "Un-registered Accumulator: " + getClass.getSimpleName
+      "Un-registered Accumulator: " + Utils.getSimpleName(getClass)
     } else {
-      getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
+      Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f9191a59c1655..7428db2158538 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
 
 import java.io._
 import java.lang.{Byte => JByte}
+import java.lang.InternalError
 import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
 import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}
@@ -1820,7 +1821,7 @@ private[spark] object Utils extends Logging {
 
   /** Return the class name of the given object, removing all dollar signs */
   def getFormattedClassName(obj: AnyRef): String = {
-    obj.getClass.getSimpleName.replace("$", "")
+    getSimpleName(obj.getClass).replace("$", "")
   }
 
   /**
@@ -2715,6 +2716,62 @@ private[spark] object Utils extends Logging {
     HashCodes.fromBytes(secretBytes).toString()
   }
 
+  /**
+   * Safer than Class obj's getSimpleName, which may throw a Malformed class name error in Scala.
+   * This method mimics scalatest's getSimpleNameOfAnObjectsClass.
+   */
+  def getSimpleName(cls: Class[_]): String = {
+    try {
+      return cls.getSimpleName
+    } catch {
+      case err: InternalError => return stripDollars(stripPackages(cls.getName))
+    }
+  }
+
+  /**
+   * Remove the packages from the fully qualified class name
+   */
+  private def stripPackages(fullyQualifiedName: String): String = {
+    fullyQualifiedName.split("\\.").takeRight(1)(0)
+  }
+
+  /**
+   * Remove trailing dollar signs from a qualified class name,
+   * and return the trailing part after the last dollar sign in the middle
+   */
+  private def stripDollars(s: String): String = {
+    val lastDollarIndex = s.lastIndexOf('$')
+    if (lastDollarIndex < s.length - 1) {
+      // The last char is not a dollar sign
+      if (lastDollarIndex == -1 || !s.contains("$iw")) {
+        // The name does not have a dollar sign or is not an interpreter
+        // generated class, so we should return the full string
+        s
+      } else {
+        // The class name is interpreter generated,
+        // return the part after the last dollar sign
+        // This is the same behavior as getClass.getSimpleName
+        s.substring(lastDollarIndex + 1)
+      }
+    }
+    else {
+      // The last char is a dollar sign
+      // Find the last non-dollar char
+      val lastNonDollarChar = s.reverse.find(_ != '$')
+      lastNonDollarChar match {
+        case None => s
+        case Some(c) =>
+          val lastNonDollarIndex = s.lastIndexOf(c)
+          if (lastNonDollarIndex == -1) {
+            s
+          } else {
+            // Strip the trailing dollar signs
+            // Invoke stripDollars again to get the simple name
+            stripDollars(s.substring(0, lastNonDollarIndex + 1))
+          }
+      }
+    }
+  }
 }
 
 private[util] object CallerContext extends Logging {
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 3b4273184f1e9..418d2f9b88500 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
       Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
     }
   }
+
+  object MalformedClassObject {
+    class MalformedClass
+  }
+
+  test("Safe getSimpleName") {
+    // getSimpleName on the class of MalformedClass results in a Malformed class name error,
+    // whereas Utils.getSimpleName works
+    val err = intercept[java.lang.InternalError] {
+      classOf[MalformedClassObject.MalformedClass].getSimpleName
+    }
+    assert(err.getMessage === "Malformed class name")
+
+    assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) ===
+      "UtilsSuite$MalformedClassObject$MalformedClass")
+  }
 }
 
 private class SimpleExtension
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 3a1c166d46257..11f46eb9e4359 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -30,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.Param
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.util.Utils
 
 /**
  * A small wrapper that defines a training session for an estimator, and some methods to log
@@ -47,7 +48,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
 
   private val id = UUID.randomUUID()
   private val prefix = {
-    val className = estimator.getClass.getSimpleName
+    // estimator.getClass.getSimpleName can cause a Malformed class name error,
+    // call the safer `Utils.getSimpleName` instead
+    val className = Utils.getSimpleName(estimator.getClass)
     s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index aab8cc50b9526..6d44890704f49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 object TypedAggregateExpression {
   def apply[BUF : Encoder, OUT : Encoder](
@@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction {
     s"$nodeName($input)"
   }
 
-  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+  // aggregator.getClass.getSimpleName can cause a Malformed class name error,
+  // call the safer `Utils.getSimpleName` instead
+  override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$")
 }
 
 // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
index 693e67dcd108e..97e6c6d702acb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
@@ -53,7 +53,9 @@ trait DataSourceV2StringFormat {
 
   private def sourceName: String = source match {
     case registered: DataSourceRegister => registered.shortName()
-    case _ => source.getClass.getSimpleName.stripSuffix("$")
+    // source.getClass.getSimpleName can cause a Malformed class name error,
+    // call the safer `Utils.getSimpleName` instead
+    case _ => Utils.getSimpleName(source.getClass)
   }
 
   def metadataString: String = {