From eb1fffe5a49901a3c5f3ac22716bfd5e52acde74 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 28 Oct 2014 18:27:00 +0800 Subject: [PATCH] Correctly check case sensitivity in GetField --- .../apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 14 +++++++ .../sql/catalyst/analysis/unresolved.scala | 12 ++++++ .../spark/sql/catalyst/dsl/package.scala | 4 +- .../catalyst/expressions/complexTypes.scala | 39 +++++++++++-------- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 37 ++---------------- .../ExpressionEvaluationSuite.scala | 2 +- .../optimizer/ConstantFoldingSuite.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- 10 files changed, 60 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 4e967713ede64..fb09a0771c849 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -363,7 +363,7 @@ class SqlParser extends AbstractSparkSQLParser { | expression ~ ("[" ~> expression <~ "]") ^^ { case base ~ ordinal => GetItem(base, ordinal) } | (expression <~ ".") ~ ident ^^ - { case base ~ fieldName => GetField(base, fieldName) } + { case base ~ fieldName => UnresolvedGetField(base, fieldName) } | cast | "(" ~> expression <~ ")" | function diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a448c794213ae..19e89a99b6508 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,6 +52,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool NewRelationInstances), Batch("Resolution", fixedPoint, ResolveReferences :: + ResolveGetField :: ResolveRelations :: ResolveSortReferences :: NewRelationInstances :: @@ -165,6 +166,19 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * Replaces [[UnresolvedGetField]]s with concrete [[GetField]] + */ + object ResolveGetField extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case q: LogicalPlan if q.childrenResolved => + q transformExpressionsUp { + case u @ UnresolvedGetField(child, fieldName) if child.resolved => + GetField(u.child, u.fieldName, resolver) + } + } + } + /** * In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 77d84e1687e1b..9ca0def65a272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -121,3 +121,15 @@ case class Star( override def toString = table.map(_ + ".").getOrElse("") + "*" } + +case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression { + override def dataType = throw new UnresolvedException(this, "dataType") + override def foldable = throw new UnresolvedException(this, "foldable") + override def nullable = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + override def toString = s"$child.$fieldName" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 75b6e37c2a1f9..ce39a204ee3c8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -96,7 +96,7 @@ package object dsl { def isNotNull = IsNotNull(expr) def getItem(ordinal: Expression) = GetItem(expr, ordinal) - def getField(fieldName: String) = GetField(expr, fieldName) + def getField(fieldName: String) = UnresolvedGetField(expr, fieldName) def cast(to: DataType) = Cast(expr, to) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 19421e5667138..a316d20540baa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField + import scala.collection.Map import org.apache.spark.sql.catalyst.types._ @@ -73,33 +75,38 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { /** * Returns the value of fields in the Struct `child`. */ -case class GetField(child: Expression, fieldName: String) extends UnaryExpression { +case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { type EvaluatedType = Any def dataType = field.dataType override def nullable = child.nullable || field.nullable override def foldable = child.foldable - protected def structType = child.dataType match { - case s: StructType => s - case otherType => sys.error(s"GetField is not valid on fields of type $otherType") - } - - lazy val field = - structType.fields - .find(_.name == fieldName) - .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}")) - - lazy val ordinal = structType.fields.indexOf(field) - - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType] - override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] if (baseValue == null) null else baseValue(ordinal) } - override def toString = s"$child.$fieldName" + override def toString = s"$child.${field.name}" +} + +object GetField { + def apply( + e: Expression, + fieldName: String, + equality: (String, String) => Boolean = _ == _): GetField = { + val structType = e.dataType match { + case s: StructType => s + case otherType => sys.error(s"GetField is not valid on fields of type $otherType") + } + val field = structType.fields + .find(f => equality(f.name, fieldName)) + .getOrElse(sys.error(s"No such field $fieldName in ${e.dataType}")) + val ordinal = structType.fields.indexOf(field) + GetField(e, field, ordinal) + } + + def apply(ug: UnresolvedGetField): GetField = GetField(ug.child, ug.fieldName) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9ce7c78195830..a2a523609af9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -206,7 +206,7 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) + case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ed578e081be73..fead0bac420e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedGetField} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.trees.TreeNode /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the @@ -160,11 +159,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - val aliased = - Alias( - resolveNesting(nestedFields, a, resolver), - nestedFields.last)() // Preserve the case of the user's field access. - Some(aliased) + Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)()) // No matches. case Seq() => @@ -177,32 +172,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") } } - - /** - * Given a list of successive nested field accesses, and a based expression, attempt to resolve - * the actual field lookups on this expression. - */ - private def resolveNesting( - nestedFields: List[String], - expression: Expression, - resolver: Resolver): Expression = { - - (nestedFields, expression.dataType) match { - case (Nil, _) => expression - case (requestedField :: rest, StructType(fields)) => - val actualField = fields.filter(f => resolver(f.name, requestedField)) - actualField match { - case Seq() => - sys.error( - s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") - case Seq(singleMatch) => - resolveNesting(rest, GetField(expression, singleMatch.name), resolver) - case multipleMatches => - sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}") - } - case (_, dt) => sys.error(s"Can't access nested field in type $dt") - } - } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 53c53481f984e..d4bbb66e7c4df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -507,7 +507,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) - checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row) + checkEvaluation(GetField('c.struct(typeS).at(2).getField("a")), "aa", row) } test("arithmetic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 0a27cce337482..c5d97aacdb539 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -184,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest { GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, - GetField( + UnresolvedGetField( Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), "a") as 'c5, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 54c619722ee12..4902cb1ac80b3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -895,7 +895,7 @@ private[hive] object HiveQl { nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) - case other => GetField(other, attr) + case other => UnresolvedGetField(other, attr) } /* Stars (*) */