diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 574d96d92942b..71810b798bd04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -37,6 +37,8 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean) extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { + val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution + // TODO: pass this in as a parameter. val fixedPoint = FixedPoint(100) @@ -48,8 +50,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool lazy val batches: Seq[Batch] = Seq( Batch("MultiInstanceRelations", Once, NewRelationInstances), - Batch("CaseInsensitiveAttributeReferences", Once, - (if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*), Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: @@ -98,23 +98,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } - /** - * Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase. 
- */ - object LowercaseAttributeReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case UnresolvedRelation(databaseName, name, alias) => - UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase)) - case Subquery(alias, child) => Subquery(alias.toLowerCase, child) - case q: LogicalPlan => q transformExpressions { - case s: Star => s.copy(table = s.table.map(_.toLowerCase)) - case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase) - case Alias(c, name) => Alias(c, name.toLowerCase)() - case GetField(c, name) => GetField(c, name.toLowerCase) - } - } - } - /** * Replaces [[UnresolvedAttribute]]s with concrete * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's @@ -127,7 +110,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. 
- val result = q.resolveChildren(name).getOrElse(u) + val result = q.resolveChildren(name, resolver).getOrElse(u) logDebug(s"Resolving $u to $result") result } @@ -144,7 +127,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) - val resolved = unresolved.flatMap(child.resolveChildren) + val resolved = unresolved.flatMap(child.resolve(_, resolver)) val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) val missingInProject = requiredAttributes -- p.output @@ -154,6 +137,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Sort(ordering, Project(projectList ++ missingInProject, child))) } else { + logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. 
} case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => @@ -165,7 +149,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ) logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve) + val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver)) val missingInAggs = resolved.filterNot(a.outputSet.contains) logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") if (missingInAggs.nonEmpty) { @@ -258,14 +242,14 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case p @ Project(projectList, child) if containsStar(projectList) => Project( projectList.flatMap { - case s: Star => s.expand(child.output) + case s: Star => s.expand(child.output, resolver) case o => o :: Nil }, child) case t: ScriptTransformation if containsStar(t.input) => t.copy( input = t.input.flatMap { - case s: Star => s.expand(t.child.output) + case s: Star => s.expand(t.child.output, resolver) case o => o :: Nil } ) @@ -273,7 +257,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case a: Aggregate if containsStar(a.aggregateExpressions) => a.copy( aggregateExpressions = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child.output) + case s: Star => s.expand(a.child.output, resolver) case o => o :: Nil } ) @@ -290,13 +274,11 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool /** * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are * only required to provide scoping information for attributes and can be removed once analysis is - * complete. Similarly, this node also removes - * [[catalyst.plans.logical.LowerCaseSchema LowerCaseSchema]] operators. + * complete. 
*/ object EliminateAnalysisOperators extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Subquery(_, child) => child - case LowerCaseSchema(child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 9f37ca904ffeb..3f672a3e0fd91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -22,4 +22,14 @@ package org.apache.spark.sql.catalyst * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s * into fully typed objects using information in a schema [[Catalog]]. */ -package object analysis +package object analysis { + + /** + * Responsible for resolving which identifiers refer to the same entity. For example, by using + * case insensitive equality. + */ + type Resolver = (String, String) => Boolean + + val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b) + val caseSensitiveResolution = (a: String, b: String) => a == b +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a2c61c65487cb..67570a6f73c36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -54,6 +54,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def newInstance = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this + override def withName(newName: String) = UnresolvedAttribute(newName) // Unresolved attributes are transient at compile time and don't 
get evaluated during execution. override def eval(input: Row = null): EvaluatedType = @@ -97,13 +98,14 @@ case class Star( override def newInstance = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this + override def withName(newName: String) = this - def expand(input: Seq[Attribute]): Seq[NamedExpression] = { + def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { // If there is no table specified, use all input attributes. case None => input // If there is a table, pick out attributes that are part of this table. - case Some(t) => input.filter(_.qualifiers contains t) + case Some(t) => input.filter(_.qualifiers.exists(resolver(_, t))) } val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map { case (n: NamedExpression, _) => n diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7c4b9d4847e26..59fb0311a9c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -59,6 +59,7 @@ abstract class Attribute extends NamedExpression { def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute + def withName(newName: String): Attribute def toAttribute = this def newInstance: Attribute @@ -86,7 +87,6 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable - override def toAttribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers) @@ -144,6 +144,14 @@ case class AttributeReference(name: String, dataType: DataType, 
nullable: Boolea } } + override def withName(newName: String): AttributeReference = { + if (name == newName) { + this + } else { + AttributeReference(newName, dataType, nullable)(exprId, qualifiers) + } + } + /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ede431ad4ab27..28d863e58beca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees -abstract class LogicalPlan extends QueryPlan[LogicalPlan] { +abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { self: Product => /** @@ -75,20 +77,25 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolveChildren(name: String): Option[NamedExpression] = - resolve(name, children.flatMap(_.output)) + def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] = + resolve(name, children.flatMap(_.output), resolver) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. 
*/ - def resolve(name: String): Option[NamedExpression] = - resolve(name, output) + def resolve(name: String, resolver: Resolver): Option[NamedExpression] = + resolve(name, output, resolver) /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { + protected def resolve( + name: String, + input: Seq[Attribute], + resolver: Resolver): Option[NamedExpression] = { + val parts = name.split("\\.") + // Collect all attributes that are output by this nodes children where either the first part // matches the name or where the first part matches the scope and the second part matches the // name. Return these matches along with any remaining parts, which represent dotted access to @@ -96,21 +103,69 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { val options = input.flatMap { option => // If the first part of the desired name matches a qualifier for this possible match, drop it. val remainingParts = - if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts - if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil + if (option.qualifiers.exists(resolver(_, parts.head)) && parts.size > 1) { + parts.drop(1) + } else { + parts + } + + if (resolver(option.name, remainingParts.head)) { + // Preserve the case of the user's attribute reference. + (option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil + } else { + Nil + } } options.distinct match { - case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. + // One match, no nested fields, use it. + case Seq((a, Nil)) => Some(a) + // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) - case Seq() => None // No matches. 
+ val aliased = + Alias( + resolveNesting(nestedFields, a, resolver), + nestedFields.last)() // Preserve the case of the user's field access. + Some(aliased) + + // No matches. + case Seq() => + logTrace(s"Could not find $name in ${input.mkString(", ")}") + None + + // More than one match. case ambiguousReferences => throw new TreeNodeException( this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") } } + + /** + * Given a list of successive nested field accesses, and a base expression, attempt to resolve + * the actual field lookups on this expression. + */ + private def resolveNesting( + nestedFields: List[String], + expression: Expression, + resolver: Resolver): Expression = { + + (nestedFields, expression.dataType) match { + case (Nil, _) => expression + case (requestedField :: rest, StructType(fields)) => + val actualField = fields.filter(f => resolver(f.name, requestedField)) + actualField match { + case Seq() => + sys.error( + s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") + case Seq(singleMatch) => + resolveNesting(rest, GetField(expression, singleMatch.name), resolver) + case multipleMatches => + sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}") + } + case (_, dt) => sys.error(s"Can't access nested field in type $dt") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8e8259cae6670..391508279bb80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -165,32 +165,6 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output = child.output.map(_.withQualifiers(alias :: Nil)) } -/** - * Converts the schema of `child` to 
all lowercase, together with LowercaseAttributeReferences - * this allows for optional case insensitive attribute resolution. This node can be elided after - * analysis. - */ -case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { - protected def lowerCaseSchema(dataType: DataType): DataType = dataType match { - case StructType(fields) => - StructType(fields.map(f => - StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable))) - case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull) - case otherType => otherType - } - - override val output = child.output.map { - case a: AttributeReference => - AttributeReference( - a.name.toLowerCase, - lowerCaseSchema(a.dataType), - a.nullable)( - a.exprId, - a.qualifiers) - case other => other - } -} - case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7dbaf7faff0c0..b245e1a863cc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -246,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(None, tableName, rdd.queryExecution.analyzed) + catalog.registerTable(None, tableName, rdd.queryExecution.logical) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 15f6bcef93886..08376eb5e5c4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -381,7 +381,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3349 
partitioning after limit") { - /* sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") @@ -396,7 +395,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), (1, "a", 1) :: (2, "b", 2) :: Nil) - */ } test("mixed-case keywords") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index e0be09e6793ea..3e1a7b71528e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -244,15 +244,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient - override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { - override def lookupRelation( - databaseName: Option[String], - tableName: String, - alias: Option[String] = None): LogicalPlan = { - - LowerCaseSchema(super.lookupRelation(databaseName, tableName, alias)) - } - } + override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog // Note that HiveUDFs will be overridden by functions registered in this context. @transient diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2c0db9be57e54..6b4399e852c7b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -129,14 +129,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Wait until children are resolved. 
case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable( - LowerCaseSchema(table: MetastoreRelation), _, child, _) => + case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - LowerCaseSchema( InMemoryRelation(_, _, _, - HiveTableScan(_, table, _))), _, child, _) => + HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 43dd3d234f73a..8ac17f37201a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} @@ -55,7 +55,7 @@ private[hive] trait HiveStrategies { object ParquetConversion extends Strategy { implicit class LogicalPlanHacks(s: SchemaRDD) { def lowerCase = - new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan)) + new SchemaRDD(s.sqlContext, s.logicalPlan) def addPartitioningAttributes(attrs: Seq[Attribute]) = new SchemaRDD( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 71ea774d77795..1017fe6d5396d 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -21,7 +21,6 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LowerCaseSchema import org.apache.spark.sql.execution.{SparkPlan, Command, LeafNode} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.MetastoreRelation @@ -52,8 +51,7 @@ case class CreateTableAsSelect( sc.catalog.createTable(database, tableName, query.output, false) // Get the Metastore Relation sc.catalog.lookupRelation(Some(database), tableName, None) match { - case LowerCaseSchema(r: MetastoreRelation) => r - case o: MetastoreRelation => o + case r: MetastoreRelation => r } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 5a0e6c5cc1bba..19ff3b66ad7ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -44,10 +44,11 @@ private[hive] abstract class HiveFunctionRegistry def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. 
- val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( - sys.error(s"Couldn't find function $name")) + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $name")) - val functionClassName = functionInfo.getFunctionClass.getName() + val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] diff --git a/sql/hive/src/test/resources/golden/database.table table.attr case insensitive-0-98b2e34c9134208e9fe7c62d33010005 b/sql/hive/src/test/resources/golden/database.table table.attr case insensitive-0-98b2e34c9134208e9fe7c62d33010005 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/database.table table.attr case insensitive-0-98b2e34c9134208e9fe7c62d33010005 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b6be6bc1bfefe..ee9d08ff75450 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -36,6 +36,9 @@ class HiveResolutionSuite extends HiveComparisonTest { createQueryTest("database.table table.attr", "SELECT src.key FROM default.src ORDER BY key LIMIT 1") + createQueryTest("database.table table.attr case insensitive", + "SELECT SRC.Key FROM Default.Src ORDER BY key LIMIT 1") + createQueryTest("alias.attr", "SELECT a.key FROM src a ORDER BY key LIMIT 1") @@ -56,14 +59,18 @@ class HiveResolutionSuite extends HiveComparisonTest { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) 
.registerTempTable("caseSensitivityTest") - sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") - - println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution) - - sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect() + val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), + "The output schema did not preserve the case of the query.") + query.collect() + } - // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a") + ignore("case insensitivity with scala reflection joins") { + // Test resolution with Scala Reflection + TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .registerTempTable("caseSensitivityTest") + sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") {