[SPARK-44287][SQL][FOLLOWUP] Set partition index correctly
### What changes were proposed in this pull request?

This is a followup of #41839. It sets the partition index correctly in the non-evaluator code path, even though the index is not used for now, and it also includes a few code cleanups.
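
For context, here is a minimal, self-contained sketch of the `PartitionEvaluator` pattern these operators use (assuming Spark 3.5+; `TagEvaluator` and related names are illustrative, not code from this PR). It shows both the evaluator path and the fallback path, where the partition index must be threaded through explicitly:

```scala
import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
import org.apache.spark.sql.SparkSession

// Illustrative evaluator: tags each value with the partition index it came from.
class TagEvaluator extends PartitionEvaluator[Int, String] {
  override def eval(partitionIndex: Int, inputs: Iterator[Int]*): Iterator[String] =
    inputs.head.map(v => s"partition=$partitionIndex value=$v")
}

class TagEvaluatorFactory extends PartitionEvaluatorFactory[Int, String] {
  override def createEvaluator(): PartitionEvaluator[Int, String] = new TagEvaluator
}

object PartitionIndexDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    val rdd = spark.sparkContext.parallelize(1 to 4, 2)
    val factory = new TagEvaluatorFactory

    // Evaluator path: Spark hands eval() the real partition index.
    rdd.mapPartitionsWithEvaluator(factory).collect().foreach(println)

    // Fallback path: the index must be threaded through explicitly. Hardcoding 0
    // here (as the pre-fix code did) would report every partition as partition 0.
    rdd.mapPartitionsWithIndex { (index, iter) =>
      factory.createEvaluator().eval(index, iter)
    }.collect().foreach(println)

    spark.stop()
  }
}
```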

### Why are the changes needed?

Future-proofing: the fallback (non-evaluator) path should pass the real partition index to `PartitionEvaluator.eval`, so that any evaluator that later comes to rely on the index behaves correctly.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

existing tests

Closes #42185 from cloud-fan/follow.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit bf1bbc5)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan committed Jul 28, 2023
1 parent 4f90c32 commit 4f71878
Showing 3 changed files with 57 additions and 74 deletions.
@@ -85,20 +85,16 @@ case class ColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition w
   )
 
   override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-    val numInputBatches = longMetric("numInputBatches")
-    val evaluatorFactory =
-      new ColumnarToRowEvaluatorFactory(
-        child.output,
-        numOutputRows,
-        numInputBatches)
-
+    val evaluatorFactory = new ColumnarToRowEvaluatorFactory(
+      child.output,
+      longMetric("numOutputRows"),
+      longMetric("numInputBatches"))
     if (conf.usePartitionEvaluator) {
       child.executeColumnar().mapPartitionsWithEvaluator(evaluatorFactory)
     } else {
-      child.executeColumnar().mapPartitionsInternal { batches =>
+      child.executeColumnar().mapPartitionsWithIndexInternal { (index, batches) =>
         val evaluator = evaluatorFactory.createEvaluator()
-        evaluator.eval(0, batches)
+        evaluator.eval(index, batches)
       }
     }
   }
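
The hardcoded `0` above was only harmless because the current evaluators ignore their partition index. A hypothetical evaluator that does consult it, say one deriving a per-partition random seed, would silently misbehave under the old code; an illustrative sketch (not part of this PR):

```scala
import org.apache.spark.PartitionEvaluator

// Hypothetical evaluator that does use its partition index: it derives a
// per-partition random seed. If every partition were handed index 0 (the old
// hardcoded value), all partitions would draw identical random sequences.
class SeededEvaluator(baseSeed: Long) extends PartitionEvaluator[Long, Long] {
  override def eval(partitionIndex: Int, inputs: Iterator[Long]*): Iterator[Long] = {
    val rng = new scala.util.Random(baseSeed + partitionIndex)
    inputs.head.map(_ + rng.nextLong())
  }
}
```
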
@@ -454,25 +450,20 @@ case class RowToColumnarExec(child: SparkPlan) extends RowToColumnarTransition {
   )
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    val numInputRows = longMetric("numInputRows")
-    val numOutputBatches = longMetric("numOutputBatches")
-    // Instead of creating a new config we are reusing columnBatchSize. In the future if we do
-    // combine with some of the Arrow conversion tools we will need to unify some of the configs.
-    val numRows = conf.columnBatchSize
-    val evaluatorFactory =
-      new RowToColumnarEvaluatorFactory(
-        conf.offHeapColumnVectorEnabled,
-        numRows,
-        schema,
-        numInputRows,
-        numOutputBatches)
-
+    val evaluatorFactory = new RowToColumnarEvaluatorFactory(
+      conf.offHeapColumnVectorEnabled,
+      // Instead of creating a new config we are reusing columnBatchSize. In the future if we do
+      // combine with some of the Arrow conversion tools we will need to unify some of the configs.
+      conf.columnBatchSize,
+      schema,
+      longMetric("numInputRows"),
+      longMetric("numOutputBatches"))
     if (conf.usePartitionEvaluator) {
       child.execute().mapPartitionsWithEvaluator(evaluatorFactory)
     } else {
-      child.execute().mapPartitionsInternal { rowIterator =>
+      child.execute().mapPartitionsWithIndexInternal { (index, rowIterator) =>
         val evaluator = evaluatorFactory.createEvaluator()
-        evaluator.eval(0, rowIterator)
+        evaluator.eval(index, rowIterator)
       }
     }
   }
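
For reference, the values read in `doExecuteColumnar` map to user-facing SQL configs. A sketch of setting them on a session (keys taken via `SQLConf` constants that appear in this codebase; the values are illustrative, not recommendations):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// conf.columnBatchSize: rows packed into each ColumnarBatch.
spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, 4096L)
// conf.offHeapColumnVectorEnabled: back batches with off-heap column vectors.
spark.conf.set(SQLConf.COLUMN_VECTOR_OFFHEAP_ENABLED.key, true)
// conf.usePartitionEvaluator: pick the evaluator branch in the code above.
spark.conf.set(SQLConf.USE_PARTITION_EVALUATOR.key, true)
```
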
@@ -279,40 +279,36 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
     }
     withSession(extensions) { session =>
       session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
-      Seq(true, false).foreach { enableEvaluator =>
-        withSQLConf(SQLConf.USE_PARTITION_EVALUATOR.key -> enableEvaluator.toString) {
-          assert(session.sessionState.columnarRules.contains(
-            MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
-          import session.sqlContext.implicits._
-          // perform a join to inject a broadcast exchange
-          val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
-          val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
-          val data = left.join(right, $"l1" === $"r1")
-            // repartitioning avoids having the add operation pushed up into the LocalTableScan
-            .repartition(1)
-          val df = data.selectExpr("l2 + r2")
-          // execute the plan so that the final adaptive plan is available when AQE is on
-          df.collect()
-          val found = collectPlanSteps(df.queryExecution.executedPlan).sum
-          // 1 MyBroadcastExchangeExec
-          // 1 MyShuffleExchangeExec
-          // 1 ColumnarToRowExec
-          // 2 ColumnarProjectExec
-          // 1 ReplacedRowToColumnarExec
-          // so 11121 is expected.
-          assert(found == 11121)
-
-          // Verify that we get back the expected, wrong, result
-          val result = df.collect()
-          assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
-          assert(result(1).getLong(0) == 201L)
-          assert(result(2).getLong(0) == 301L)
-
-          withTempPath { path =>
-            val e = intercept[Exception](df.write.parquet(path.getCanonicalPath))
-            assert(e.getMessage == "columnar write")
-          }
-        }
-      }
+      assert(session.sessionState.columnarRules.contains(
+        MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
+      import session.sqlContext.implicits._
+      // perform a join to inject a broadcast exchange
+      val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
+      val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
+      val data = left.join(right, $"l1" === $"r1")
+        // repartitioning avoids having the add operation pushed up into the LocalTableScan
+        .repartition(1)
+      val df = data.selectExpr("l2 + r2")
+      // execute the plan so that the final adaptive plan is available when AQE is on
+      df.collect()
+      val found = collectPlanSteps(df.queryExecution.executedPlan).sum
+      // 1 MyBroadcastExchangeExec
+      // 1 MyShuffleExchangeExec
+      // 1 ColumnarToRowExec
+      // 2 ColumnarProjectExec
+      // 1 ReplacedRowToColumnarExec
+      // so 11121 is expected.
+      assert(found == 11121)
+
+      // Verify that we get back the expected, wrong, result
+      val result = df.collect()
+      assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
+      assert(result(1).getLong(0) == 201L)
+      assert(result(2).getLong(0) == 301L)
+
+      withTempPath { path =>
+        val e = intercept[Exception](df.write.parquet(path.getCanonicalPath))
+        assert(e.getMessage == "columnar write")
+      }
     }
   }
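
The `11121` assertion works because `collectPlanSteps` (defined elsewhere in this suite) appears to tag each replaced operator with a distinct power of ten, so each decimal digit of the sum counts one operator type. A toy illustration of that encoding (the marker values here are hypothetical):

```scala
// Toy version of the digit-encoding trick: each operator type contributes to
// its own power of ten, so the digits of the sum are per-operator counts.
val markers = Seq(
  10000,  // 1 x MyBroadcastExchangeExec
  1000,   // 1 x MyShuffleExchangeExec
  100,    // 1 x ColumnarToRowExec
  10, 10, // 2 x ColumnarProjectExec
  1       // 1 x ReplacedRowToColumnarExec
)
assert(markers.sum == 11121)
```
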
@@ -127,22 +127,18 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
 
   test("SPARK-37779: ColumnarToRowExec should be canonicalizable after being (de)serialized") {
     withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") {
-      Seq(true, false).foreach { enable =>
-        withSQLConf(SQLConf.USE_PARTITION_EVALUATOR.key -> enable.toString) {
-          withTempPath { path =>
-            spark.range(1).write.parquet(path.getAbsolutePath)
-            val df = spark.read.parquet(path.getAbsolutePath)
-            val columnarToRowExec =
-              df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
-            try {
-              spark.range(1).foreach { _ =>
-                columnarToRowExec.canonicalized
-                ()
-              }
-            } catch {
-              case e: Throwable => fail("ColumnarToRowExec was not canonicalizable", e)
-            }
-          }
-        }
-      }
+      withTempPath { path =>
+        spark.range(1).write.parquet(path.getAbsolutePath)
+        val df = spark.read.parquet(path.getAbsolutePath)
+        val columnarToRowExec =
+          df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
+        try {
+          spark.range(1).foreach { _ =>
+            columnarToRowExec.canonicalized
+            ()
+          }
+        } catch {
+          case e: Throwable => fail("ColumnarToRowExec was not canonicalizable", e)
+        }
+      }
     }
   }
