Skip to content

Commit

Permalink
[SPARK-3414][SQL] Replace LowerCaseSchema with Resolver
Browse files Browse the repository at this point in the history
**This PR introduces a subtle change in semantics for HiveContext when using the results in Python or Scala.  Specifically, while resolution remains case insensitive, it is now case preserving.**

_This PR is a follow up to #2293 (and to a lesser extent #2262 #2334)._

In #2293 the catalog was changed to store analyzed logical plans instead of unresolved ones.  While this change fixed the reported bug (which was caused by yet another instance of us forgetting to put in a `LowerCaseSchema` operator) it had the consequence of breaking assumptions made by `MultiInstanceRelation`.  Specifically, we can't replace swap out leaf operators in a tree without rewriting changed expression ids (which happens when you self join the same RDD that has been registered as a temp table).

In this PR, I instead remove the need to insert `LowerCaseSchema` operators at all, by moving the concern of matching up identifiers completely into analysis.  Doing so allows the test cases from both #2293 and #2262 to pass at the same time (and likely fixes a slew of other "unknown unknown" bugs).

While it is rolled back in this PR, storing the analyzed plan might actually be a good idea.  For instance, it is kind of confusing if you register a temporary table, change the case sensitivity of resolution and now you can't query that table anymore.  This can be addressed in a follow up PR.

Follow-ups:
 - Configurable case sensitivity
 - Consider storing analyzed plans for temp tables

Author: Michael Armbrust <michael@databricks.com>

Closes #2382 from marmbrus/lowercase and squashes the following commits:

c21171e [Michael Armbrust] Ensure the resolver is used for field lookups and ensure that case insensitive resolution is still case preserving.
d4320f1 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into lowercase
2de881e [Michael Armbrust] Address comments.
219805a [Michael Armbrust] style
5b93711 [Michael Armbrust] Replace LowerCaseSchema with Resolver.
  • Loading branch information
marmbrus committed Sep 20, 2014
1 parent 7f54580 commit 293ce85
Show file tree
Hide file tree
Showing 15 changed files with 125 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {

val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution

// TODO: pass this in as a parameter.
val fixedPoint = FixedPoint(100)

Expand All @@ -48,8 +50,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
lazy val batches: Seq[Batch] = Seq(
Batch("MultiInstanceRelations", Once,
NewRelationInstances),
Batch("CaseInsensitiveAttributeReferences", Once,
(if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*),
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveRelations ::
Expand Down Expand Up @@ -98,23 +98,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}

/**
* Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase.
*/
object LowercaseAttributeReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case UnresolvedRelation(databaseName, name, alias) =>
UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase))
case Subquery(alias, child) => Subquery(alias.toLowerCase, child)
case q: LogicalPlan => q transformExpressions {
case s: Star => s.copy(table = s.table.map(_.toLowerCase))
case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase)
case Alias(c, name) => Alias(c, name.toLowerCase)()
case GetField(c, name) => GetField(c, name.toLowerCase)
}
}
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete
* [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's
Expand All @@ -127,7 +110,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
q transformExpressions {
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result = q.resolveChildren(name).getOrElse(u)
val result = q.resolveChildren(name, resolver).getOrElse(u)
logDebug(s"Resolving $u to $result")
result
}
Expand All @@ -144,7 +127,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
val resolved = unresolved.flatMap(child.resolveChildren)
val resolved = unresolved.flatMap(child.resolve(_, resolver))
val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })

val missingInProject = requiredAttributes -- p.output
Expand All @@ -154,6 +137,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
Sort(ordering,
Project(projectList ++ missingInProject, child)))
} else {
logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
}
case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
Expand All @@ -165,7 +149,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
)

logDebug(s"Grouping expressions: $groupingRelation")
val resolved = unresolved.flatMap(groupingRelation.resolve)
val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
val missingInAggs = resolved.filterNot(a.outputSet.contains)
logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
if (missingInAggs.nonEmpty) {
Expand Down Expand Up @@ -258,22 +242,22 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
case p @ Project(projectList, child) if containsStar(projectList) =>
Project(
projectList.flatMap {
case s: Star => s.expand(child.output)
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
},
child)
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child.output)
case s: Star => s.expand(t.child.output, resolver)
case o => o :: Nil
}
)
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
a.copy(
aggregateExpressions = a.aggregateExpressions.flatMap {
case s: Star => s.expand(a.child.output)
case s: Star => s.expand(a.child.output, resolver)
case o => o :: Nil
}
)
Expand All @@ -290,13 +274,11 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
/**
* Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are
* only required to provide scoping information for attributes and can be removed once analysis is
* complete. Similarly, this node also removes
* [[catalyst.plans.logical.LowerCaseSchema LowerCaseSchema]] operators.
* complete.
*/
object EliminateAnalysisOperators extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Subquery(_, child) => child
case LowerCaseSchema(child) => child
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,14 @@ package org.apache.spark.sql.catalyst
* Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s
* into fully typed objects using information in a schema [[Catalog]].
*/
package object analysis
package object analysis {

/**
* Responsible for resolving which identifiers refer to the same entity. For example, by using
* case insensitive equality.
*/
type Resolver = (String, String) => Boolean

val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b)
val caseSensitiveResolution = (a: String, b: String) => a == b
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def newInstance = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
override def withName(newName: String) = UnresolvedAttribute(name)

// Unresolved attributes are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
Expand Down Expand Up @@ -97,13 +98,14 @@ case class Star(
override def newInstance = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
override def withName(newName: String) = this

def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
val expandedAttributes: Seq[Attribute] = table match {
// If there is no table specified, use all input attributes.
case None => input
// If there is a table, pick out attributes that are part of this table.
case Some(t) => input.filter(_.qualifiers contains t)
case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty)
}
val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map {
case (n: NamedExpression, _) => n
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ abstract class Attribute extends NamedExpression {

def withNullability(newNullability: Boolean): Attribute
def withQualifiers(newQualifiers: Seq[String]): Attribute
def withName(newName: String): Attribute

def toAttribute = this
def newInstance: Attribute
Expand Down Expand Up @@ -86,7 +87,6 @@ case class Alias(child: Expression, name: String)
override def dataType = child.dataType
override def nullable = child.nullable


override def toAttribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
Expand Down Expand Up @@ -144,6 +144,14 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
}
}

override def withName(newName: String): AttributeReference = {
if (name == newName) {
this
} else {
AttributeReference(newName, dataType, nullable)(exprId, qualifiers)
}
}

/**
* Returns a copy of this [[AttributeReference]] with new qualifiers.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.trees

abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
self: Product =>

/**
Expand Down Expand Up @@ -75,42 +77,95 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
* nodes of this LogicalPlan. The attribute is expressed as
* as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
def resolveChildren(name: String): Option[NamedExpression] =
resolve(name, children.flatMap(_.output))
def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
resolve(name, children.flatMap(_.output), resolver)

/**
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
* LogicalPlan. The attribute is expressed as string in the following form:
* `[scope].AttributeName.[nested].[fields]...`.
*/
def resolve(name: String): Option[NamedExpression] =
resolve(name, output)
def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
resolve(name, output, resolver)

/** Performs attribute resolution given a name and a sequence of possible attributes. */
protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
protected def resolve(
name: String,
input: Seq[Attribute],
resolver: Resolver): Option[NamedExpression] = {

val parts = name.split("\\.")

// Collect all attributes that are output by this nodes children where either the first part
// matches the name or where the first part matches the scope and the second part matches the
// name. Return these matches along with any remaining parts, which represent dotted access to
// struct fields.
val options = input.flatMap { option =>
// If the first part of the desired name matches a qualifier for this possible match, drop it.
val remainingParts =
if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
if (option.qualifiers.find(resolver(_, parts.head)).nonEmpty && parts.size > 1) {
parts.drop(1)
} else {
parts
}

if (resolver(option.name, remainingParts.head)) {
// Preserve the case of the user's attribute reference.
(option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil
} else {
Nil
}
}

options.distinct match {
case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
// One match, no nested fields, use it.
case Seq((a, Nil)) => Some(a)

// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
case Seq() => None // No matches.
val aliased =
Alias(
resolveNesting(nestedFields, a, resolver),
nestedFields.last)() // Preserve the case of the user's field access.
Some(aliased)

// No matches.
case Seq() =>
logTrace(s"Could not find $name in ${input.mkString(", ")}")
None

// More than one match.
case ambiguousReferences =>
throw new TreeNodeException(
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
}
}

/**
* Given a list of successive nested field accesses, and a based expression, attempt to resolve
* the actual field lookups on this expression.
*/
private def resolveNesting(
nestedFields: List[String],
expression: Expression,
resolver: Resolver): Expression = {

(nestedFields, expression.dataType) match {
case (Nil, _) => expression
case (requestedField :: rest, StructType(fields)) =>
val actualField = fields.filter(f => resolver(f.name, requestedField))
actualField match {
case Seq() =>
sys.error(
s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
case Seq(singleMatch) =>
resolveNesting(rest, GetField(expression, singleMatch.name), resolver)
case multipleMatches =>
sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}")
}
case (_, dt) => sys.error(s"Can't access nested field in type $dt")
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,32 +165,6 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output = child.output.map(_.withQualifiers(alias :: Nil))
}

/**
* Converts the schema of `child` to all lowercase, together with LowercaseAttributeReferences
* this allows for optional case insensitive attribute resolution. This node can be elided after
* analysis.
*/
case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
protected def lowerCaseSchema(dataType: DataType): DataType = dataType match {
case StructType(fields) =>
StructType(fields.map(f =>
StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable)))
case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull)
case otherType => otherType
}

override val output = child.output.map {
case a: AttributeReference =>
AttributeReference(
a.name.toLowerCase,
lowerCaseSchema(a.dataType),
a.nullable)(
a.exprId,
a.qualifiers)
case other => other
}
}

case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
catalog.registerTable(None, tableName, rdd.queryExecution.analyzed)
catalog.registerTable(None, tableName, rdd.queryExecution.logical)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}

test("SPARK-3349 partitioning after limit") {
/*
sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
.limit(2)
.registerTempTable("subset1")
Expand All @@ -396,7 +395,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"),
(1, "a", 1) ::
(2, "b", 2) :: Nil)
*/
}

test("mixed-case keywords") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

/* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
override def lookupRelation(
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan = {

LowerCaseSchema(super.lookupRelation(databaseName, tableName, alias))
}
}
override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog

// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
Expand Down
Loading

0 comments on commit 293ce85

Please sign in to comment.