refactor default column value resolution

cloud-fan committed May 23, 2023
1 parent 411bcd2 commit be16454
Showing 16 changed files with 510 additions and 1,056 deletions.
4 changes: 2 additions & 2 deletions core/src/main/resources/error/error-classes.json
@@ -3176,14 +3176,14 @@
"_LEGACY_ERROR_TEMP_1202" : {
"message" : [
"Cannot write to '<tableName>', too many data columns:",
"Table columns: <tableColumns>.",
"Table columns (excluding columns with static partition values): <tableColumns>.",
"Data columns: <dataColumns>."
]
},
"_LEGACY_ERROR_TEMP_1203" : {
"message" : [
"Cannot write to '<tableName>', not enough data columns:",
"Table columns: <tableColumns>.",
"Table columns (excluding columns with static partition values): <tableColumns>.",
"Data columns: <dataColumns>."
]
},
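For illustration only (not part of this commit), a minimal Scala sketch of the scenario the amended messages describe, assuming a hypothetical partitioned table t and an active spark session: once a partition column is given a static value, it no longer expects data from the query, so the reported table columns leave it out.

// Hypothetical example: 'p' is supplied statically, so only (a, b) expect query data.
spark.sql("CREATE TABLE t (a INT, b STRING, p INT) USING parquet PARTITIONED BY (p)")
// Three data columns for two remaining table columns -> "too many data columns",
// with the listed table columns excluding the static partition column 'p'.
spark.sql("INSERT INTO t PARTITION (p = 1) SELECT 1, 'x', 2")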

This file was deleted.

@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -55,8 +55,7 @@ import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssig
import org.apache.spark.sql.internal.connector.V1Function
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.DAY
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.util.collection.{Utils => CUtils}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and
@@ -280,7 +279,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
KeepLegacyOutputs),
Batch("Resolution", fixedPoint,
new ResolveCatalogs(catalogManager) ::
ResolveUserSpecifiedColumns ::
ResolveInsertInto ::
ResolveRelations ::
ResolvePartitionSpec ::
@@ -313,7 +311,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
TimeWindowing ::
SessionWindowing ::
ResolveWindowTime ::
ResolveDefaultColumns(ResolveRelations.resolveRelationOrTempView) ::
ResolveInlineTables ::
ResolveLambdaVariables ::
ResolveTimeZone ::
@@ -1080,7 +1077,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

def apply(plan: LogicalPlan)
: LogicalPlan = plan.resolveOperatorsUpWithPruning(AlwaysProcess.fn, ruleId) {
case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(table, _, _, _, _, _) =>
val relation = table match {
case u: UnresolvedRelation if !u.isStreaming =>
resolveRelation(u).getOrElse(u)
@@ -1278,53 +1275,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}

/** Handle INSERT INTO for DSv2 */
object ResolveInsertInto extends Rule[LogicalPlan] {

/** Add a project to use the table column names for INSERT INTO BY NAME */
private def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = {
SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver)

if (i.userSpecifiedCols.size != i.query.output.size) {
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
i.userSpecifiedCols.size, i.query.output.size, i.query)
}
val projectByName = i.userSpecifiedCols.zip(i.query.output)
.map { case (userSpecifiedCol, queryOutputCol) =>
val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver)
.getOrElse(
throw QueryCompilationErrors.unresolvedAttributeError(
"UNRESOLVED_COLUMN", userSpecifiedCol, i.table.output.map(_.name), i.origin))
(queryOutputCol.dataType, resolvedCol.dataType) match {
case (input: StructType, expected: StructType) =>
// Rename inner fields of the input column to pass the by-name INSERT analysis.
Alias(Cast(queryOutputCol, renameFieldsInStruct(input, expected)), resolvedCol.name)()
case _ =>
Alias(queryOutputCol, resolvedCol.name)()
}
}
Project(projectByName, i.query)
}

private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = {
if (input.length == expected.length) {
val newFields = input.zip(expected).map { case (f1, f2) =>
(f1.dataType, f2.dataType) match {
case (s1: StructType, s2: StructType) =>
f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2))
case _ =>
f1.copy(name = f2.name)
}
}
StructType(newFields)
} else {
input
}
}

object ResolveInsertInto extends ResolveInsertionBase {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
AlwaysProcess.fn, ruleId) {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
if i.query.resolved =>
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) if i.query.resolved =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
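As a hedged illustration of the by-name INSERT path that the deleted helper served (and that the new ResolveInsertionBase base class presumably centralizes), assuming hypothetical tables and a spark session:

// Hypothetical example: the user-specified column list (name, id) is turned into a
// Project that aliases and reorders the query output to match the table schema.
spark.sql("CREATE TABLE target (id BIGINT, name STRING) USING parquet")
spark.sql("INSERT INTO target (name, id) SELECT 'a', 1L")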
@@ -1527,6 +1481,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
// Don't wait for other rules to resolve the child plans of `InsertIntoStatement`: the explicit
// column "DEFAULT" must be resolved while the child plans are still unresolved.
case i: InsertIntoStatement => ResolveColumnDefaultInInsert(i)

// Wait for other rules to resolve child plans first
case p: LogicalPlan if !p.childrenResolved => p
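A hedged reading of the ResolveColumnDefaultInInsert case above, with a hypothetical table and assuming a data source that supports column defaults: the explicit DEFAULT keyword exists only as an unresolved attribute, so it has to be rewritten before the child plan is resolved.

// Hypothetical example: DEFAULT in the VALUES list is replaced by the column's
// default expression ('active') while the child plan is still unresolved.
spark.sql("CREATE TABLE t (id INT, status STRING DEFAULT 'active') USING parquet")
spark.sql("INSERT INTO t VALUES (1, DEFAULT)")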

@@ -1646,6 +1604,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// implementation and should be resolved based on the table schema.
o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table))

case u: UpdateTable => ResolveReferencesInUpdate(u)

case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _)
if !m.resolved && targetTable.resolved && sourceTable.resolved =>

@@ -1796,23 +1756,31 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case MergeResolvePolicy.SOURCE => Project(Nil, mergeInto.sourceTable)
case MergeResolvePolicy.TARGET => Project(Nil, mergeInto.targetTable)
}
resolveMergeExprOrFail(c, resolvePlan)
resolvedKey match {
case attr: AttributeReference =>
val resolvedExpr = resolveExprInAssignment(c, resolvePlan) match {
case u: UnresolvedAttribute if isExplicitDefaultColumn(u) =>
getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType))
case other => other
}
checkResolvedMergeExpr(resolvedExpr, resolvePlan)
resolvedExpr
case _ => resolveMergeExprOrFail(c, resolvePlan)
}
case o => o
}
Assignment(resolvedKey, resolvedValue)
}
}

private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
val resolved = resolveExpressionByPlanChildren(e, p)
resolved.references.filter { attribute: Attribute =>
!attribute.resolved &&
// We exclude attribute references named "DEFAULT" from consideration since they are
// handled exclusively by the ResolveDefaultColumns analysis rule. That rule checks the
// MERGE command for such references and either replaces each one with a corresponding
// value, or returns a custom error message.
normalizeFieldName(attribute.name) != normalizeFieldName(CURRENT_DEFAULT_COLUMN_NAME)
}.foreach { a =>
val resolved = resolveExprInAssignment(e, p)
checkResolvedMergeExpr(resolved, p)
resolved
}

private def checkResolvedMergeExpr(e: Expression, p: LogicalPlan): Unit = {
e.references.filter(!_.resolved).foreach { a =>
// Note: This will throw error only on unresolved attribute issues,
// not other resolution errors like mismatched data types.
val cols = p.inputSet.toSeq.map(_.sql).mkString(", ")
@@ -1822,10 +1790,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
"sqlExpr" -> a.sql,
"cols" -> cols))
}
resolved match {
case Alias(child: ExtractValue, _) => child
case other => other
}
}
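For context, a hedged MERGE sketch (hypothetical target/source tables, spark session): with the new branch above, an explicit DEFAULT on the value side of an assignment resolves to the target column's default value (or a null literal) instead of surfacing as an unresolved attribute.

// Hypothetical example of explicit DEFAULT in MERGE assignments.
spark.sql("""
  MERGE INTO target t
  USING source s
  ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET t.status = DEFAULT
  WHEN NOT MATCHED THEN INSERT (id, status) VALUES (s.id, DEFAULT)
""")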

// Expand the star expression using the input plan first. If failed, try resolve
@@ -3346,53 +3310,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* A special rule to reorder columns for DSv1 when users specify a column list in INSERT INTO.
* DSv2 is handled by [[ResolveInsertInto]] separately.
*/
object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
AlwaysProcess.fn, ruleId) {
case i: InsertIntoStatement if !i.table.isInstanceOf[DataSourceV2Relation] &&
i.table.resolved && i.query.resolved && i.userSpecifiedCols.nonEmpty =>
val resolved = resolveUserSpecifiedColumns(i)
val projection = addColumnListOnQuery(i.table.output, resolved, i.query)
i.copy(userSpecifiedCols = Nil, query = projection)
}

private def resolveUserSpecifiedColumns(i: InsertIntoStatement): Seq[NamedExpression] = {
SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver)

i.userSpecifiedCols.map { col =>
i.table.resolve(Seq(col), resolver).getOrElse {
val candidates = i.table.output.map(_.qualifiedName)
val orderedCandidates = StringUtils.orderSuggestedIdentifiersBySimilarity(col, candidates)
throw QueryCompilationErrors
.unresolvedAttributeError("UNRESOLVED_COLUMN", col, orderedCandidates, i.origin)
}
}
}

private def addColumnListOnQuery(
tableOutput: Seq[Attribute],
cols: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = {
if (cols.size != query.output.size) {
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
cols.size, query.output.size, query)
}
val nameToQueryExpr = CUtils.toMap(cols, query.output)
// Static partition columns in the table output should not appear in the column list
// they will be handled in another rule ResolveInsertInto
val reordered = tableOutput.flatMap { nameToQueryExpr.get(_).orElse(None) }
if (reordered == query.output) {
query
} else {
Project(reordered, query)
}
}
}

private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLiteral
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{DataType, StructType}
@@ -103,8 +104,11 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
case assignment if assignment.key.semanticEquals(attr) => assignment
}
val resolvedValue = if (matchingAssignments.isEmpty) {
errors += s"No assignment for '${attr.name}'"
attr
val defaultExpr = getDefaultValueExprOrNullLiteral(attr, conf)
if (defaultExpr.isEmpty) {
errors += s"No assignment for '${attr.name}'"
}
defaultExpr.getOrElse(attr)
} else if (matchingAssignments.length > 1) {
val conflictingValuesStr = matchingAssignments.map(_.value.sql).mkString(", ")
errors += s"Multiple assignments for '${attr.name}': $conflictingValuesStr"
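A hedged sketch of what this fallback permits (hypothetical tables; whether a partial insert column list is accepted also depends on the target table): a MERGE insert action that omits a column declaring a default no longer fails with "No assignment for ...", because getDefaultValueExprOrNullLiteral fills in the missing value.

// Hypothetical example: 'status' is omitted and falls back to its declared default.
spark.sql("""
  MERGE INTO target t
  USING source s
  ON t.id = s.id
  WHEN NOT MATCHED THEN INSERT (id) VALUES (s.id)
""")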
@@ -384,6 +384,14 @@ trait ColumnResolutionHelper extends Logging {
allowOuter = allowOuter)
}

def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = {
resolveExpressionByPlanChildren(expr, hostPlan) match {
// Assignment keys and values do not need the alias when resolving nested columns.
case Alias(child: ExtractValue, _) => child
case other => other
}
}

private def resolveExpressionByPlanId(
e: Expression,
q: LogicalPlan): Expression = {
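Finally, a hedged illustration of the new resolveExprInAssignment helper (hypothetical table with a struct column, and a source that supports UPDATE): resolving the key of a nested-field assignment yields an Alias over a GetStructField, and only the extracted field itself is meaningful as an assignment target, so the alias is stripped.

// Hypothetical example: the key 'addr.city' resolves to GetStructField(addr, "city");
// the Alias added by generic resolution is dropped by resolveExprInAssignment.
spark.sql("UPDATE people SET addr.city = 'NYC' WHERE id = 1")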