Commit 9c7e65b
implement the FillStaticPartitions rule in the optimizer and fix a bug
fusheng committed Oct 8, 2024
1 parent 92e79e3 commit 9c7e65b
Showing 4 changed files with 274 additions and 5 deletions.
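At a high level, the commit adds a FillStaticPartitions optimizer rule: when an INSERT OVERWRITE uses dynamic partitions but the source query pins a partition column to a single literal (via = or a one-element IN), the rule records that value as a "fill" static partition, and InsertIntoHadoopFsRelationCommand uses it to narrow the catalog partition listing. A minimal sketch of the targeted shape, adapted from the test cases below (table names A and B come from the test; a SparkSession `spark` is assumed):

// Sketch only, assuming a SparkSession `spark` and the partitioned source table B
// created as in the test below.
spark.sql("create table A(id int) using parquet partitioned by (p1 string, p2 string)")

// p1 is dynamic in the INSERT, but the source filter pins it to a single literal,
// so the rule records p1 -> '20240712' and the catalog partition listing is narrowed.
spark.sql("insert overwrite A partition(p1, p2) " +
  "select id, p1, p2 from B where p1 = '20240712'")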
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.{FillStaticPartitions, PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
@@ -94,7 +94,8 @@ class SparkOptimizer(
ConstantFolding) :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) :+
Batch("RewriteCollationJoin", Once, RewriteCollationJoin)
Batch("RewriteCollationJoin", Once, RewriteCollationJoin) :+
Batch("FillStaticPartitions", Once, FillStaticPartitions)

override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
ExtractPythonUDFFromJoinCondition.ruleName :+
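The commit registers the rule as a dedicated Once batch at the end of the optimizer, as shown above. For local experimentation, a rule of this kind could also be injected through the extraOptimizations hook that backs the "User Provided Optimizers" batch in the same method; a minimal sketch, assuming the rule is on the classpath (note the hook runs it at a different point in the pipeline than the dedicated batch):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.FillStaticPartitions

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// Registers the rule in the "User Provided Optimizers" batch; the commit instead adds
// a dedicated Once batch after RewriteCollationJoin.
spark.experimental.extraOptimizations = Seq(FillStaticPartitions)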
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FillStaticPartitions.scala (new file)
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Literal, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{AS_OF_JOIN, EXCEPT, INNER_LIKE_JOIN, INTERSECT, JOIN, LATERAL_JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, UNION}

object FillStaticPartitions extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan =
    plan.transformWithPruning(!_.containsAnyPattern(OUTER_JOIN, JOIN, LATERAL_JOIN, AS_OF_JOIN,
      INNER_LIKE_JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, INTERSECT, EXCEPT, UNION)) {
      case i @ InsertIntoHadoopFsRelationCommand(
          _, _, _, partitionColumns, _, _, _, query, _, _, _, _, _)
          if i.catalogTable.nonEmpty && i.staticPartitions.isEmpty &&
            i.fillStaticPartitions.isEmpty =>
        val fillStaticPartitions = mutable.Map[String, String]()

        query.foreach {
          // Exclude the case where the project list computes a partition column
          // instead of passing an attribute through unchanged.
          case Project(projectList, _) =>
            val partitionColumnContainsEval =
              !projectList.filter(x => partitionColumns.map(_.name).contains(x.name))
                .map { project =>
                  val leaves = project.collectLeaves()
                  leaves.size == 1 && leaves.head.isInstanceOf[AttributeReference]
                }.reduceLeft(_ && _)
            if (partitionColumnContainsEval) {
              return i
            }
          case PhysicalOperation(_, filters,
              logicalRelation @ LogicalRelation(
                fsRelation @ HadoopFsRelation(_, partitionSchema, _, _, _, _), _, _, _))
              if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
            val normalizedFilters = DataSourceStrategy.normalizeExprs(
              filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
              logicalRelation.output)
            val (partitionKeyFilters, _) = DataSourceUtils
              .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)

            partitionKeyFilters.foreach {
              // p = 'literal' pins the partition column to a single value.
              case EqualTo(AttributeReference(name, _, _, _), Literal(value, _)) =>
                fillStaticPartitions += (name -> value.toString)
              // p IN ('literal') with exactly one element behaves like an equality.
              case In(AttributeReference(name, _, _, _), Seq(Literal(value, _))) =>
                fillStaticPartitions += (name -> value.toString)
              case _ => // do nothing
            }
          case _ => // do nothing
        }

        i.copy(fillStaticPartitions = fillStaticPartitions.toMap)
    }
}
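The heart of the rule is the mapping from literal partition predicates to fixed partition values. Below is a Spark-free sketch of that extraction; Attr, Lit, Eq and In are illustrative stand-ins, not the Catalyst expression classes:

object StaticPartitionExtraction {
  // Illustrative stand-ins for Catalyst expression nodes.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Lit(value: Any) extends Expr
  final case class Eq(left: Expr, right: Expr) extends Expr
  final case class In(attr: Expr, values: Seq[Expr]) extends Expr

  // p = 'x' and p IN ('x') both pin partition column p to the single value 'x',
  // which is what FillStaticPartitions records per partition filter.
  def fillStaticPartitions(partitionFilters: Seq[Expr]): Map[String, String] =
    partitionFilters.collect {
      case Eq(Attr(name), Lit(value)) => name -> value.toString
      case In(Attr(name), Seq(Lit(value))) => name -> value.toString
    }.toMap

  def main(args: Array[String]): Unit = {
    val filters = Seq(Eq(Attr("p1"), Lit("20240712")), In(Attr("p2"), Seq(Lit("1"))))
    println(fillStaticPartitions(filters)) // Map(p1 -> 20240712, p2 -> 1)
  }
}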
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -56,7 +56,8 @@ case class InsertIntoHadoopFsRelationCommand(
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumnNames: Seq[String])
outputColumnNames: Seq[String],
fillStaticPartitions: Map[String, String] = Map.empty)
extends V1WriteCommand {

private lazy val parameters = CaseInsensitiveMap(options)
@@ -102,7 +103,7 @@ case class InsertIntoHadoopFsRelationCommand(
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
matchingPartitions = sparkSession.sessionState.catalog.listPartitions(
catalogTable.get.identifier, Some(staticPartitions))
catalogTable.get.identifier, Some(staticPartitions ++ fillStaticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations = getCustomPartitionLocations(
fs, catalogTable.get, qualifiedOutputPath, matchingPartitions)
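The behavioral change in this file is that the partial partition spec passed to listPartitions now also carries the inferred values, so the catalog lookup matches only the partitions pinned by the query instead of every partition of the table. A small sketch of the merge semantics with plain Scala maps (the real spec type is the catalog's TablePartitionSpec):

val staticPartitions = Map.empty[String, String]       // no static values in the PARTITION clause
val fillStaticPartitions = Map("p1" -> "20240712")     // inferred by the rule from the WHERE clause
// Partial spec handed to listPartitions; right-hand entries win on key collisions,
// and an empty merged map would still mean "list every partition".
val partialSpec = staticPartitions ++ fillStaticPartitions // Map(p1 -> 20240712)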
175 changes: 174 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
@@ -18,13 +18,20 @@
package org.apache.spark.sql

import org.apache.spark.{SparkConf, SparkNumberFormatException, SparkThrowable}
import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.CommandResult
import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
import org.apache.spark.sql.execution.CommandResultExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.unsafe.types.UTF8String

@@ -569,6 +576,172 @@ class FileSourceSQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession
super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, format)
}

def getInsertIntoHadoopFsRelationCommandPartitionMembers(df: DataFrame):
(Seq[CatalogTablePartition], TablePartitionSpec, Map[String, String], Boolean) = {
val commandResults = df.queryExecution.optimizedPlan.collect {
case _ @ CommandResult(_, _, commandPhysicalPlan, _) =>
commandPhysicalPlan match {
case d: DataWritingCommandExec =>
d.cmd.asInstanceOf[InsertIntoHadoopFsRelationCommand]
case a: AdaptiveSparkPlanExec if a.inputPlan.isInstanceOf[DataWritingCommandExec] =>
a.inputPlan.asInstanceOf[DataWritingCommandExec].cmd
.asInstanceOf[InsertIntoHadoopFsRelationCommand]
}
}
val insertIntoHadoopFsRelationCommand = commandResults.head
val matchingPartitions = spark.sessionState.catalog.listPartitions(
insertIntoHadoopFsRelationCommand.catalogTable.get.identifier,
Some(insertIntoHadoopFsRelationCommand.staticPartitions
++ insertIntoHadoopFsRelationCommand.fillStaticPartitions))
val staticPartitions = insertIntoHadoopFsRelationCommand.staticPartitions
val fillStaticPartitions = insertIntoHadoopFsRelationCommand
.fillStaticPartitions
val dynamicPartitionOverwrite = insertIntoHadoopFsRelationCommand.dynamicPartitionOverwrite
(matchingPartitions, staticPartitions, fillStaticPartitions, dynamicPartitionOverwrite)
}

test("SPARK-48881: test some dynamic partitions can be compensated to " +
"specific partition values") {
withTable("A", "B") {
spark.sessionState.conf.setConfString("spark.sql.sources.partitionOverwriteMode",
PartitionOverwriteMode.DYNAMIC.toString)
sql("create table A(id int) using parquet partitioned by " +
"(p1 string, p2 string)")
spark.range(10).selectExpr("id").withColumns(Seq("p1", "p2"),
Seq(col("id").cast("string"), col("id").cast("string"))).write
.partitionBy("p1", "p2").saveAsTable("B")

sql("insert overwrite A partition(p1='20240701', p2='1') values(11)")
sql("insert overwrite A partition(p1='20240702', p2='2') values(12)")

// insert overwrite t1 partition(p) select p from t where p = 1
// this situation will be optimized by this optimization rule.
val df = sql("insert overwrite A partition(p1, p2) " +
"select id, p1, p2 from B where p1 in ('20240712') and id = 8")
val (matchingPartitions, staticPartitions, fillStaticPartitions, dynamicPartitionOverwrite) =
getInsertIntoHadoopFsRelationCommandPartitionMembers(df)

assert(staticPartitions.isEmpty)
assert(fillStaticPartitions == Map("p1" -> "20240712"))
assert(dynamicPartitionOverwrite)
assert(matchingPartitions.isEmpty)

// insert overwrite t1 partition(p1='1', p2) select p2 from t where p1 = 1
// a partially static partition spec will not be optimized by this optimization rule.
val df2 = sql("insert overwrite A partition(p1='20240712', p2) " +
"select id, p2 from B where p1 in ('20240712') and id = 8")
val (matchingPartitions2, staticPartitions2, fillStaticPartitions2,
dynamicPartitionOverwrite2) = getInsertIntoHadoopFsRelationCommandPartitionMembers(df2)
assert(staticPartitions2 == Map("p1" -> "20240712"))
assert(fillStaticPartitions2.isEmpty)
assert(dynamicPartitionOverwrite2)
assert(matchingPartitions2.isEmpty)

// insert overwrite t1 partition(p1='1', p2='1') select id from t where p1 = 1
// a fully static partition spec will not be optimized by this optimization rule.
val df3 = sql("insert overwrite A partition(p1='20240712', p2='1') " +
"select id from B where p1 in ('20240712') and id = 8")
val (matchingPartitions3, staticPartitions3, fillStaticPartitions3,
dynamicPartitionOverwrite3) = getInsertIntoHadoopFsRelationCommandPartitionMembers(df3)

assert(staticPartitions3 == Map("p1" -> "20240712", "p2" -> "1"))
assert(fillStaticPartitions3.isEmpty)
assert(dynamicPartitionOverwrite3)
assert(matchingPartitions3.map(_.spec) == Seq(
Map("p1" -> "20240712", "p2" -> "1")
))

// union situation will not be optimized by this optimization rule.
val df4 = sql("insert overwrite A partition(p1, p2) " +
"select id, p1, p2 from B where p1 in ('20240712') and id = 8 " +
"union select id, p1, p2 from A where p1 in ('20240713') and id = 8 ")
val (matchingPartitions4, staticPartitions4, fillStaticPartitions4,
dynamicPartitionOverwrite4) = getInsertIntoHadoopFsRelationCommandPartitionMembers(df4)

assert(staticPartitions4.isEmpty)
assert(fillStaticPartitions4.isEmpty)
assert(dynamicPartitionOverwrite4)
assert(matchingPartitions4.map(_.spec).sortWith((x, y) =>
x.values.toSeq.head.toLong < y.values.toSeq.head.toLong) == Seq(
Map("p1" -> "20240701", "p2" -> "1"),
Map("p1" -> "20240702", "p2" -> "2"),
Map("p1" -> "20240712", "p2" -> "1")
))

// join situation will not be optimized by this optimization rule.
val df5 = sql("insert overwrite A partition(p1, p2) " +
"select t1.id, t1.p1, t1.p2 from B t1 left join A t2 " +
"on t1.id = t2.id where t1.p1 in ('20240712') and t1.id = 8 and " +
"t2.p1 in ('20240713') and t2.id = 8")
val (matchingPartitions5, staticPartitions5, fillStaticPartitions5,
dynamicPartitionOverwrite5) = getInsertIntoHadoopFsRelationCommandPartitionMembers(df5)

assert(staticPartitions5.isEmpty)
assert(fillStaticPartitions5.isEmpty)
assert(dynamicPartitionOverwrite5)
assert(matchingPartitions5.map(_.spec).sortWith((x, y) =>
x.values.toSeq.head.toLong < y.values.toSeq.head.toLong) == Seq(
Map("p1" -> "20240701", "p2" -> "1"),
Map("p1" -> "20240702", "p2" -> "2"),
Map("p1" -> "20240712", "p2" -> "1")
))

// insert into table t1 partition(p) select (p+1) as p from t where p = 1
// this situation will not be optimized by this optimization rule.
val df6 = sql("insert overwrite A partition(p1, p2) " +
"select id, (p1 + 1) as p1, p2 from B " +
"where p1 in ('20240712') and id = 8")
val (matchingPartitions6, staticPartitions6, fillStaticPartitions6,
dynamicPartitionOverwrite6) = getInsertIntoHadoopFsRelationCommandPartitionMembers(df6)
assert(staticPartitions6.isEmpty)
assert(fillStaticPartitions6.isEmpty)
assert(dynamicPartitionOverwrite6)
assert(matchingPartitions6.size == 3 && matchingPartitions6.map(_.spec).sortWith((x, y) =>
x.values.toSeq.head.toLong < y.values.toSeq.head.toLong) == Seq(
Map("p1" -> "20240701", "p2" -> "1"),
Map("p1" -> "20240702", "p2" -> "2"),
Map("p1" -> "20240712", "p2" -> "1")
))

// insert into table t1 partition(p) select p from t where p = 1 and p = 2
// this situation will not be optimized by this optimization rule.
val df7 = sql("insert overwrite A partition(p1, p2) " +
"select id, p1, p2 from B " +
"where p1 = '20240712' and p1 = '20240713' and id = 8")
val (matchingPartitions7, staticPartitions7, fillStaticPartitions7,
dynamicPartitionOverwrite7) = getInsertIntoHadoopFsRelationCommandPartitionMembers(df7)

assert(staticPartitions7.isEmpty)
assert(fillStaticPartitions7.isEmpty)
assert(dynamicPartitionOverwrite7)
assert(matchingPartitions7.size == 3 && matchingPartitions7.map(_.spec).sortWith((x, y) =>
x.values.toSeq.head.toLong < y.values.toSeq.head.toLong) == Seq(
Map("p1" -> "20240701", "p2" -> "1"),
Map("p1" -> "20240702", "p2" -> "2"),
Map("p1" -> "20240712", "p2" -> "1")
))

// insert into table t1 partition(p1, p2) select * from t2 where p2 = 1
// this situation is optimized by this rule: only p2 is pinned to a literal, so the
// catalog lookup is narrowed instead of listing all table partitions, as expected.
val df8 = sql("insert overwrite A partition(p1, p2) " +
"select id, p1, p2 from B " +
"where p2 = '1' and id = 8")
val (matchingPartitions8, staticPartitions8, fillStaticPartitions8,
dynamicPartitionOverwrite8) = getInsertIntoHadoopFsRelationCommandPartitionMembers(df8)

assert(staticPartitions8.isEmpty)
assert(fillStaticPartitions8.size == 1 &&
fillStaticPartitions8("p2") == "1")
assert(dynamicPartitionOverwrite8)
assert(matchingPartitions8.size == 2 && matchingPartitions8.map(_.spec).sortWith((x, y) =>
x.values.toSeq.head.toLong < y.values.toSeq.head.toLong) == Seq(
Map("p1" -> "20240701", "p2" -> "1"),
Map("p1" -> "20240712", "p2" -> "1")
))
}
}

}

class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession {
