Skip to content

Commit

Permalink
Correctly check case sensitivity in GetField
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Oct 28, 2014
1 parent 27470d3 commit eb1fffe
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class SqlParser extends AbstractSparkSQLParser {
| expression ~ ("[" ~> expression <~ "]") ^^
{ case base ~ ordinal => GetItem(base, ordinal) }
| (expression <~ ".") ~ ident ^^
{ case base ~ fieldName => GetField(base, fieldName) }
{ case base ~ fieldName => UnresolvedGetField(base, fieldName) }
| cast
| "(" ~> expression <~ ")"
| function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
NewRelationInstances),
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveGetField ::
ResolveRelations ::
ResolveSortReferences ::
NewRelationInstances ::
Expand Down Expand Up @@ -165,6 +166,19 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}

/**
* Replaces [[UnresolvedGetField]]s with concrete [[GetField]]
*/
object ResolveGetField extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case q: LogicalPlan if q.childrenResolved =>
q transformExpressionsUp {
case u @ UnresolvedGetField(child, fieldName) if child.resolved =>
GetField(u.child, u.fieldName, resolver)
}
}
}

/**
* In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,15 @@ case class Star(

override def toString = table.map(_ + ".").getOrElse("") + "*"
}

case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
override def dataType = throw new UnresolvedException(this, "dataType")
override def foldable = throw new UnresolvedException(this, "foldable")
override def nullable = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false

override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString = s"$child.$fieldName"
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}

import scala.language.implicitConversions

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
Expand Down Expand Up @@ -96,7 +96,7 @@ package object dsl {
def isNotNull = IsNotNull(expr)

def getItem(ordinal: Expression) = GetItem(expr, ordinal)
def getField(fieldName: String) = GetField(expr, fieldName)
def getField(fieldName: String) = UnresolvedGetField(expr, fieldName)

def cast(to: DataType) = Cast(expr, to)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField

import scala.collection.Map

import org.apache.spark.sql.catalyst.types._
Expand Down Expand Up @@ -73,33 +75,38 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
/**
* Returns the value of fields in the Struct `child`.
*/
case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
type EvaluatedType = Any

def dataType = field.dataType
override def nullable = child.nullable || field.nullable
override def foldable = child.foldable

protected def structType = child.dataType match {
case s: StructType => s
case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
}

lazy val field =
structType.fields
.find(_.name == fieldName)
.getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))

lazy val ordinal = structType.fields.indexOf(field)

override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType]

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}

override def toString = s"$child.$fieldName"
override def toString = s"$child.${field.name}"
}

object GetField {
def apply(
e: Expression,
fieldName: String,
equality: (String, String) => Boolean = _ == _): GetField = {
val structType = e.dataType match {
case s: StructType => s
case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
}
val field = structType.fields
.find(f => equality(f.name, fieldName))
.getOrElse(sys.error(s"No such field $fieldName in ${e.dataType}"))
val ordinal = structType.fields.indexOf(field)
GetField(e, field, ordinal)
}

def apply(ug: UnresolvedGetField): GetField = GetField(ug.child, ug.fieldName)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedGetField}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode

/**
* Estimates of various statistics. The default estimation logic simply lazily multiplies the
Expand Down Expand Up @@ -160,11 +159,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
val aliased =
Alias(
resolveNesting(nestedFields, a, resolver),
nestedFields.last)() // Preserve the case of the user's field access.
Some(aliased)
Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)())

// No matches.
case Seq() =>
Expand All @@ -177,32 +172,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
}
}

/**
* Given a list of successive nested field accesses, and a based expression, attempt to resolve
* the actual field lookups on this expression.
*/
private def resolveNesting(
nestedFields: List[String],
expression: Expression,
resolver: Resolver): Expression = {

(nestedFields, expression.dataType) match {
case (Nil, _) => expression
case (requestedField :: rest, StructType(fields)) =>
val actualField = fields.filter(f => resolver(f.name, requestedField))
actualField match {
case Seq() =>
sys.error(
s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
case Seq(singleMatch) =>
resolveNesting(rest, GetField(expression, singleMatch.name), resolver)
case multipleMatches =>
sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}")
}
case (_, dt) => sys.error(s"Can't access nested field in type $dt")
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ class ExpressionEvaluationSuite extends FunSuite {

checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
checkEvaluation(GetField('c.struct(typeS).at(2).getField("a")), "aa", row)
}

test("arithmetic") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
Expand Down Expand Up @@ -184,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest {

GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
GetField(
UnresolvedGetField(
Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@ private[hive] object HiveQl {
nodeToExpr(qualifier) match {
case UnresolvedAttribute(qualifierName) =>
UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
case other => GetField(other, attr)
case other => UnresolvedGetField(other, attr)
}

/* Stars (*) */
Expand Down

0 comments on commit eb1fffe

Please sign in to comment.