[SPARK-3414][SQL] Replace LowerCaseSchema with Resolver #2382

Status: Closed. Showing changes from 2 of 5 commits.
@@ -37,6 +37,8 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {

val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution

// TODO: pass this in as a parameter.
val fixedPoint = FixedPoint(100)

@@ -48,8 +50,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
lazy val batches: Seq[Batch] = Seq(
Batch("MultiInstanceRelations", Once,
NewRelationInstances),
Batch("CaseInsensitiveAttributeReferences", Once,
(if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*),
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveRelations ::
@@ -98,23 +98,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}

/**
* Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase.
*/
object LowercaseAttributeReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case UnresolvedRelation(databaseName, name, alias) =>
UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase))
case Subquery(alias, child) => Subquery(alias.toLowerCase, child)
case q: LogicalPlan => q transformExpressions {
case s: Star => s.copy(table = s.table.map(_.toLowerCase))
case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase)
case Alias(c, name) => Alias(c, name.toLowerCase)()
case GetField(c, name) => GetField(c, name.toLowerCase)
}
}
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete
* [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's
@@ -127,7 +110,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
q transformExpressions {
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result = q.resolveChildren(name).getOrElse(u)
val result = q.resolveChildren(name, resolver).getOrElse(u)
logDebug(s"Resolving $u to $result")
result
}
@@ -144,7 +127,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
val resolved = unresolved.flatMap(child.resolveChildren)
val resolved = unresolved.flatMap(child.resolve(_, resolver))
val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })

val missingInProject = requiredAttributes -- p.output
@@ -154,6 +137,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
Sort(ordering,
Project(projectList ++ missingInProject, child)))
} else {
logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
}
case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
@@ -165,7 +149,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
)

logDebug(s"Grouping expressions: $groupingRelation")
val resolved = unresolved.flatMap(groupingRelation.resolve)
val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
val missingInAggs = resolved.filterNot(a.outputSet.contains)
logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
if (missingInAggs.nonEmpty) {
@@ -258,22 +242,22 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
case p @ Project(projectList, child) if containsStar(projectList) =>
Project(
projectList.flatMap {
case s: Star => s.expand(child.output)
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
},
child)
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child.output)
case s: Star => s.expand(t.child.output, resolver)
case o => o :: Nil
}
)
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
a.copy(
aggregateExpressions = a.aggregateExpressions.flatMap {
case s: Star => s.expand(a.child.output)
case s: Star => s.expand(a.child.output, resolver)
case o => o :: Nil
}
)
@@ -290,13 +274,11 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
/**
* Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are
* only required to provide scoping information for attributes and can be removed once analysis is
* complete. Similarly, this node also removes
* [[catalyst.plans.logical.LowerCaseSchema LowerCaseSchema]] operators.
* complete.
*/
object EliminateAnalysisOperators extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Subquery(_, child) => child
case LowerCaseSchema(child) => child
}
}
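To make the Sort/Project rule above concrete: an ORDER BY may reference a column that the SELECT list dropped, so the rule widens the child Project, sorts, and then re-projects to the original schema. Below is a toy, self-contained model of that rewrite; the Plan, Table, Project, and Sort types are invented for illustration, and the real rule additionally checks resolution state and logs failures, so treat this as a sketch rather than the Catalyst implementation.

object SortRewriteSketch extends App {
  sealed trait Plan { def output: Seq[String] }
  case class Table(output: Seq[String]) extends Plan
  case class Project(projectList: Seq[String], child: Plan) extends Plan {
    def output = projectList
  }
  case class Sort(ordering: Seq[String], child: Plan) extends Plan {
    def output = child.output
  }

  def addMissingSortAttributes(plan: Plan): Plan = plan match {
    case s @ Sort(ordering, p @ Project(projectList, child)) =>
      val missingInProject = ordering.filterNot(p.output.contains)
      if (missingInProject.nonEmpty) {
        // Widen the Project so the Sort can resolve, then re-project to the
        // original schema, mirroring the shape built in the diff above.
        Project(p.output,
          Sort(ordering,
            Project(projectList ++ missingInProject, child)))
      } else {
        s // Nothing we can do here. Return the original plan.
      }
    case other => other
  }

  // SELECT a FROM t ORDER BY b, where t has columns a and b.
  val plan = Sort(Seq("b"), Project(Seq("a"), Table(Seq("a", "b"))))
  println(addMissingSortAttributes(plan))
}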

==== next file ====
@@ -22,4 +22,9 @@ package org.apache.spark.sql.catalyst
* Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s
* into fully typed objects using information in a schema [[Catalog]].
*/
package object analysis
package object analysis {
type Resolver = (String, String) => Boolean
Review comment (Contributor): Resolver is probably too general a name; can we use a more precise name for this?

Reply (Author): I think this will actually end up providing more general resolution functionality in the long term. I've added some Scaladoc for clarity, though.


val caseInsensitiveResolution = (a: String, b: String) => a.toLowerCase == b.toLowerCase
Review comment (Contributor): Maybe a.equalsIgnoreCase(b)?

val caseSensitiveResolution = (a: String, b: String) => a == b
}
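The core idea of the PR is visible in this package object: rather than destructively lowercasing every identifier during analysis (the old LowerCaseSchema / LowercaseAttributeReferences approach), attributes keep their original case and a Resolver decides whether two names match. A minimal sketch of the difference, using an invented Attr type and resolve stand-in (and the reviewer's equalsIgnoreCase suggestion):

object ResolverSketch extends App {
  type Resolver = (String, String) => Boolean

  val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)
  val caseSensitiveResolution: Resolver = (a, b) => a == b

  // Hypothetical attributes standing in for a plan's output; note the
  // original case ("N") is preserved, unlike under the old lowercasing scheme.
  case class Attr(name: String)
  val output = Seq(Attr("N"), Attr("l"))

  def resolve(name: String, input: Seq[Attr], resolver: Resolver): Option[Attr] =
    input.find(a => resolver(a.name, name))

  println(resolve("n", output, caseInsensitiveResolution)) // Some(Attr(N))
  println(resolve("n", output, caseSensitiveResolution))   // None
}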
==== next file ====
@@ -98,12 +98,12 @@ case class Star(
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this

def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
val expandedAttributes: Seq[Attribute] = table match {
// If there is no table specified, use all input attributes.
case None => input
// If there is a table, pick out attributes that are part of this table.
case Some(t) => input.filter(_.qualifiers contains t)
case Some(t) => input.filter(_.qualifiers.filter(resolver(_,t)).nonEmpty)
Review comment (Contributor): Nit: space after the comma.

Reply (Author): Fixed.

}
val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map {
case (n: NamedExpression, _) => n
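The behavioral change in expand is that the table qualifier of a "t.*" is now compared with the session's Resolver instead of exact equality via Seq.contains. A toy model of just that filter (Attr is invented; the real code then maps each expanded attribute through mapFunction as shown above):

object StarExpandSketch extends App {
  type Resolver = (String, String) => Boolean
  val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)

  case class Attr(name: String, qualifiers: Seq[String])

  def expand(table: Option[String], input: Seq[Attr], resolver: Resolver): Seq[Attr] =
    table match {
      case None    => input // a bare "*" takes every input attribute
      case Some(t) => input.filter(a => a.qualifiers.exists(q => resolver(q, t)))
    }

  val input = Seq(Attr("key", Seq("SRC")), Attr("value", Seq("SRC")))
  println(expand(Some("src"), input, caseInsensitive)) // both attributes match
}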
==== next file ====
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -75,19 +76,23 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
* nodes of this LogicalPlan. The attribute is expressed as
a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
def resolveChildren(name: String): Option[NamedExpression] =
resolve(name, children.flatMap(_.output))
def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
resolve(name, children.flatMap(_.output), resolver)

/**
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
* LogicalPlan. The attribute is expressed as a string in the following form:
* `[scope].AttributeName.[nested].[fields]...`.
*/
def resolve(name: String): Option[NamedExpression] =
resolve(name, output)
def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
resolve(name, output, resolver)

/** Performs attribute resolution given a name and a sequence of possible attributes. */
protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
protected def resolve(
name: String,
input: Seq[Attribute],
resolver: Resolver): Option[NamedExpression] = {

val parts = name.split("\\.")
// Collect all attributes that are output by this node's children where either the first part
// matches the name or where the first part matches the scope and the second part matches the
@@ -96,16 +101,27 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
val options = input.flatMap { option =>
// If the first part of the desired name matches a qualifier for this possible match, drop it.
val remainingParts =
if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
if (option.qualifiers.filter(resolver(_, parts.head)).nonEmpty && parts.size > 1) {
Review comment (Contributor): find would be better than filter here.

parts.drop(1)
} else {
parts
}

if (resolver(option.name, remainingParts.head)) {
(option, remainingParts.tail.toList) :: Nil
} else {
Nil
}
}

options.distinct match {
case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
case Seq() => None // No matches.
case Seq() =>
println(s"Could not find $name in ${input.mkString(", ")}")
Review comment (Contributor): Use logTrace instead? As we did in Analyzer.

None // No matches.
case ambiguousReferences =>
throw new TreeNodeException(
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
==== next file ====
@@ -154,32 +154,6 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output = child.output.map(_.withQualifiers(alias :: Nil))
}

/**
* Converts the schema of `child` to all lowercase, together with LowercaseAttributeReferences
* this allows for optional case insensitive attribute resolution. This node can be elided after
* analysis.
*/
case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
protected def lowerCaseSchema(dataType: DataType): DataType = dataType match {
case StructType(fields) =>
StructType(fields.map(f =>
StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable)))
case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull)
case otherType => otherType
}

override val output = child.output.map {
case a: AttributeReference =>
AttributeReference(
a.name.toLowerCase,
lowerCaseSchema(a.dataType),
a.nullable)(
a.exprId,
a.qualifiers)
case other => other
}
}

case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {

==== next file ====
@@ -246,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
catalog.registerTable(None, tableName, rdd.queryExecution.analyzed)
catalog.registerTable(None, tableName, rdd.queryExecution.logical)
}

/**
==== next file ====
@@ -380,7 +380,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}

test("SPARK-3349 partitioning after limit") {
/*
sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
.limit(2)
.registerTempTable("subset1")
@@ -395,7 +394,6 @@
sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"),
(1, "a", 1) ::
(2, "b", 2) :: Nil)
*/
}

test("mixed-case keywords") {
==== next file ====
@@ -244,15 +244,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

/* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
override def lookupRelation(
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan = {

LowerCaseSchema(super.lookupRelation(databaseName, tableName, alias))
}
}
override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog

// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
==== next file ====
@@ -129,14 +129,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
// Wait until children are resolved.
case p: LogicalPlan if !p.childrenResolved => p

case p @ InsertIntoTable(
LowerCaseSchema(table: MetastoreRelation), _, child, _) =>
case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
castChildOutput(p, table, child)

case p @ logical.InsertIntoTable(
LowerCaseSchema(
InMemoryRelation(_, _, _,
HiveTableScan(_, table, _))), _, child, _) =>
HiveTableScan(_, table, _)), _, child, _) =>
castChildOutput(p, table, child)
}

==== next file ====
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.StringType
import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan}
@@ -55,7 +55,7 @@ private[hive] trait HiveStrategies {
object ParquetConversion extends Strategy {
implicit class LogicalPlanHacks(s: SchemaRDD) {
def lowerCase =
new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan))
new SchemaRDD(s.sqlContext, s.logicalPlan)

def addPartitioningAttributes(attrs: Seq[Attribute]) =
new SchemaRDD(
==== next file ====
@@ -21,7 +21,6 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LowerCaseSchema
import org.apache.spark.sql.execution.{SparkPlan, Command, LeafNode}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.MetastoreRelation
@@ -52,8 +51,7 @@ case class CreateTableAsSelect(
sc.catalog.createTable(database, tableName, query.output, false)
// Get the Metastore Relation
sc.catalog.lookupRelation(Some(database), tableName, None) match {
case LowerCaseSchema(r: MetastoreRelation) => r
case o: MetastoreRelation => o
case r: MetastoreRelation => r
}
}

==== next file ====
@@ -42,10 +42,11 @@ private[hive] abstract class HiveFunctionRegistry
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
// not always serializable.
val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
sys.error(s"Couldn't find function $name"))
val functionInfo: FunctionInfo =
Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
sys.error(s"Couldn't find function $name"))

val functionClassName = functionInfo.getFunctionClass.getName()
val functionClassName = functionInfo.getFunctionClass.getName

if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
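With LowerCaseSchema gone, nothing lowercases identifiers before they reach Hive, so the function name is now normalized at lookup time instead. A sketch of the behavior this guards; the registry map below is hypothetical, and it assumes Hive's real FunctionRegistry keys function names in lower case:

object UdfLookupSketch extends App {
  // Hypothetical stand-in for Hive's FunctionRegistry.
  val registry = Map("concat" -> "GenericUDFConcat", "max" -> "GenericUDAFMax")

  def lookupFunction(name: String): String =
    registry.getOrElse(name.toLowerCase, sys.error(s"Couldn't find function $name"))

  println(lookupFunction("CONCAT")) // resolves despite the upper-case query text
}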
==== next file ====
@@ -57,13 +57,14 @@ class HiveResolutionSuite extends HiveComparisonTest {
.registerTempTable("caseSensitivityTest")

sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
}

println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution)

sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect()

// TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a")
ignore("case insensitivity with scala reflection joins") {
Review comment (Contributor): Why is this test case ignored?

Reply (Contributor): Ah, I see; a.a and b.a cause ambiguous attribute references here...

Reply (Author): Yeah, I only changed it to an explicit ignore instead of being commented out. We need to decide whether this is allowed or not.

Reply (Contributor): Hive 0.12 actually supports this.

Reply (Author): Hmm, what is the exception that is thrown? Will #2209 fix this? Otherwise we should open a JIRA.

// Test resolution with Scala Reflection
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
.registerTempTable("caseSensitivityTest")

sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect()
}

test("nested repeated resolution") {