From f1b68d897e49e77308fb75bb60d054db10f6a90c Mon Sep 17 00:00:00 2001
From: Ole Sasse
Date: Tue, 19 Nov 2024 20:25:52 +0800
Subject: [PATCH] [SPARK-50315][SQL] Support custom metrics for V1Fallback
 writes

### What changes were proposed in this pull request?

Support custom metrics for the V1Fallback write nodes (AppendDataExecV1, OverwriteByExpressionExecV1).

### Why are the changes needed?

* Register the custom metrics of the V1Write as SQL metrics on the V1FallbackWriters implementations
* Publish the metric values returned by reportDriverMetrics once the write has finished

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added a new test on top of mocked implementations.

### Was this patch authored or co-authored using generative AI tooling?

No
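For context, a minimal sketch of the connector-side contract this PR wires up. `CountingV1Write`, `RowCountMetric`, and `RowCountValue` are illustrative names, not part of this change; the mocked writer in `V1WriteFallbackSuite` below follows the same pattern:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.write.V1Write
import org.apache.spark.sql.sources.InsertableRelation

// Illustrative metric types: CustomSumMetric sums the reported values.
case class RowCountMetric(name: String, description: String) extends CustomSumMetric
case class RowCountValue(name: String, value: Long) extends CustomTaskMetric

class CountingV1Write extends V1Write {
  @volatile private var reported: Array[CustomTaskMetric] = Array.empty

  // Declared up front, so the fallback exec node can create one SQLMetric
  // per custom metric before the write runs.
  override def supportedCustomMetrics(): Array[CustomMetric] =
    Array(RowCountMetric("numOutputRows", "number of output rows"))

  // Read by the exec node after the insert completes, then posted as
  // driver metric updates.
  override def reportDriverMetrics(): Array[CustomTaskMetric] = reported

  override def toInsertableRelation: InsertableRelation =
    (data: DataFrame, _: Boolean) => {
      val rows = data.collect() // a real connector would insert the rows here
      reported = Array(RowCountValue("numOutputRows", rows.length))
    }
}
```

The exec nodes register supportedCustomMetrics() as SQLMetrics when the plan is created and publish whatever reportDriverMetrics() returns once the insert has finished.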
Closes #48867 from olaky/sc-50315-metrics-for-v1-fallback-writers.

Authored-by: Ole Sasse
Signed-off-by: Wenchen Fan
---
 .../datasources/v2/V1FallbackWriters.scala    | 23 ++++++--
 .../org/apache/spark/sql/QueryTest.scala      | 25 ++++++++-
 .../sql/connector/V1WriteFallbackSuite.scala  | 54 +++++++++++++++++++
 3 files changed, 96 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index 6f83b82785955..801151c51206d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.SupportsWrite
 import org.apache.spark.sql.connector.write.V1Write
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.sources.InsertableRelation
 
 /**
@@ -58,14 +59,27 @@ case class OverwriteByExpressionExecV1(
 sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
   override def output: Seq[Attribute] = Nil
 
+  override val metrics: Map[String, SQLMetric] =
+    write.supportedCustomMetrics().map { customMetric =>
+      customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
+    }.toMap
+
   def table: SupportsWrite
   def refreshCache: () => Unit
   def write: V1Write
 
   override def run(): Seq[InternalRow] = {
-    val writtenRows = writeWithV1(write.toInsertableRelation)
+    writeWithV1(write.toInsertableRelation)
     refreshCache()
-    writtenRows
+
+    write.reportDriverMetrics().foreach { customTaskMetric =>
+      metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
+    }
+
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+
+    Nil
   }
 }
 
@@ -75,8 +89,7 @@ sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
 trait SupportsV1Write extends SparkPlan {
   def plan: LogicalPlan
 
-  protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = {
+  protected def writeWithV1(relation: InsertableRelation): Unit = {
     relation.insert(Dataset.ofRows(session, plan), overwrite = false)
-    Nil
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f5ba655e3e85f..30180d48da71a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -27,8 +27,9 @@ import org.scalatest.Assertions
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
 
@@ -447,6 +448,28 @@ object QueryTest extends Assertions {
       case None =>
     }
   }
+
+  def withPhysicalPlansCaptured(spark: SparkSession, thunk: => Unit): Seq[SparkPlan] = {
+    var capturedPlans = Seq.empty[SparkPlan]
+
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        capturedPlans = capturedPlans :+ qe.executedPlan
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    spark.sparkContext.listenerBus.waitUntilEmpty(15000)
+    spark.listenerManager.register(listener)
+    try {
+      thunk
+      spark.sparkContext.listenerBus.waitUntilEmpty(15000)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+
+    capturedPlans
+  }
 }
 
 class QueryTestSuite extends QueryTest with test.SharedSparkSession {
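A usage sketch for the helper added above. Everything except `withPhysicalPlansCaptured` itself is illustrative; the helper records the executed plan of every query that completes inside the thunk, which is how the new test below isolates the fallback write nodes:

```scala
import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Each action inside the thunk fires the listener once, so `plans` holds
// one executed plan per completed query.
val plans = withPhysicalPlansCaptured(spark, {
  spark.range(10).count()
})
plans.foreach(plan => println(plan.getClass.getSimpleName))
```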
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index ad31cf84eeb3f..04fc7e23ebb24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -31,9 +32,12 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.v2.{AppendDataExecV1, OverwriteByExpressionExecV1}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.SQLConf.{OPTIMIZER_MAX_ITERATIONS, V2_SESSION_CATALOG_IMPLEMENTATION}
 import org.apache.spark.sql.sources._
@@ -198,6 +202,43 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
       SparkSession.setDefaultSession(spark)
     }
   }
+
+  test("SPARK-50315: metrics for V1 fallback writers") {
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+    try {
+      val session = SparkSession.builder()
+        .master("local[1]")
+        .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName)
+        .getOrCreate()
+
+      def captureWrite(sparkSession: SparkSession)(thunk: => Unit): SparkPlan = {
+        val physicalPlans = withPhysicalPlansCaptured(sparkSession, thunk)
+        val v1FallbackWritePlans = physicalPlans.filter {
+          case _: AppendDataExecV1 | _: OverwriteByExpressionExecV1 => true
+          case _ => false
+        }
+
+        assert(v1FallbackWritePlans.size === 1)
+        v1FallbackWritePlans.head
+      }
+
+      val appendPlan = captureWrite(session) {
+        val df = session.createDataFrame(Seq((1, "x")))
+        df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test")
+      }
+      assert(appendPlan.metrics("numOutputRows").value === 1)
+
+      val overwritePlan = captureWrite(session) {
+        val df2 = session.createDataFrame(Seq((2, "y")))
+        df2.writeTo("test").overwrite(lit(true))
+      }
+      assert(overwritePlan.metrics("numOutputRows").value === 1)
+    } finally {
+      SparkSession.setActiveSession(spark)
+      SparkSession.setDefaultSession(spark)
+    }
+  }
 }
 
 class V1WriteFallbackSessionCatalogSuite
@@ -376,10 +417,23 @@ class InMemoryTableWithV1Fallback(
    }
 
    override def build(): V1Write = new V1Write {
+      case class SupportedV1WriteMetric(name: String, description: String) extends CustomSumMetric
+
+      override def supportedCustomMetrics(): Array[CustomMetric] =
+        Array(SupportedV1WriteMetric("numOutputRows", "Number of output rows"))
+
+      private var writeMetrics = Array.empty[CustomTaskMetric]
+
+      override def reportDriverMetrics(): Array[CustomTaskMetric] = writeMetrics
+
      override def toInsertableRelation: InsertableRelation = {
        (data: DataFrame, overwrite: Boolean) => {
          assert(!overwrite, "V1 write fallbacks cannot be called with overwrite=true")
          val rows = data.collect()
+
+          case class V1WriteTaskMetric(name: String, value: Long) extends CustomTaskMetric
+          writeMetrics = Array(V1WriteTaskMetric("numOutputRows", rows.length))
+
          rows.groupBy(getPartitionValues).foreach { case (partition, elements) =>
            if (dataMap.contains(partition) && mode == "append") {
              dataMap.put(partition, dataMap(partition) ++ elements)
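One closing note on aggregation semantics: `CustomSumMetric`, which the mocked writer extends, combines reported values by summing, so a driver-only report like the one above surfaces as a single summed value. A self-contained sketch (`NumOutputRows` is an illustrative name):

```scala
import org.apache.spark.sql.connector.metric.CustomSumMetric

// Same pattern as SupportedV1WriteMetric in the suite above.
case class NumOutputRows(name: String, description: String) extends CustomSumMetric

val metric = NumOutputRows("numOutputRows", "Number of output rows")
// aggregateTaskMetrics sums the task values and renders them as a string.
println(metric.aggregateTaskMetrics(Array(1L, 2L))) // prints "3"
```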