[SPARK-5278][SQL] Introduce UnresolvedGetField and complete the check of ambiguous reference to fields

When the `GetField` chain (`a.b.c.d...`) is interrupted by a `GetItem`, as in `a.b[0].c.d...`, the check for ambiguous references to fields breaks.
The reason is that for something like `a.b[0].c.d`, we first parse it to `GetField(GetField(GetItem(Unresolved("a.b"), 0), "c"), "d")`. Then, in `LogicalPlan#resolve`, we resolve `"a.b"` and build a `GetField` chain from the bottom (the relation). But the two outer `GetField`s still have to be resolved, either in `Analyzer` or lazily inside `GetField` itself: checking the child's data type, searching for the requested field, and so on. That work is essentially what `LogicalPlan#resolve` already does.
So in this PR the fix simply copies that logic from `LogicalPlan#resolve` into `Analyzer`, which is simple and quick, but I do suggest introducing `UnresolvedGetField`, as I explained in #2405.
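
For illustration, here is a minimal sketch (not part of the patch) of the tree the parser produces for `a.b[0].c.d` once this change is in, assuming the constructors shown in the diff below:

```scala
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedGetField}
import org.apache.spark.sql.catalyst.expressions.{GetItem, Literal}

// `a.b[0].c.d` parses bottom-up: an unresolved attribute, an item lookup,
// then two field accesses that stay unresolved until the Analyzer resolves
// them, which is where the ambiguous-reference check can now run uniformly.
val parsed =
  UnresolvedGetField(
    UnresolvedGetField(
      GetItem(UnresolvedAttribute("a.b"), Literal(0)),
      "c"),
    "d")
```

With this in place, resolving a field like `a[0].b` against a struct that contains both `b` and `B` (under a case-insensitive resolver) fails with "Ambiguous reference to fields", as exercised by the new `HiveResolutionSuite` test below.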

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #4068 from cloud-fan/simple and squashes the following commits:

a6857b5 [Wenchen Fan] fix import order
8411c40 [Wenchen Fan] use UnresolvedGetField
cloud-fan authored and marmbrus committed Feb 6, 2015
1 parent bc36356 commit 4793c84
Showing 12 changed files with 84 additions and 88 deletions.
@@ -372,7 +372,7 @@ class SqlParser extends AbstractSparkSQLParser {
| expression ~ ("[" ~> expression <~ "]") ^^
{ case base ~ ordinal => GetItem(base, ordinal) }
| (expression <~ ".") ~ ident ^^
-  { case base ~ fieldName => GetField(base, fieldName) }
+  { case base ~ fieldName => UnresolvedGetField(base, fieldName) }
| cast
| "(" ~> expression <~ ")"
| function
@@ -285,7 +285,7 @@ class Analyzer(catalog: Catalog,

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
-q transformExpressions {
+q transformExpressionsUp {
case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
@@ -295,15 +295,8 @@
val result = q.resolveChildren(name, resolver).getOrElse(u)
logDebug(s"Resolving $u to $result")
result

-// Resolve field names using the resolver.
-case f @ GetField(child, fieldName) if !f.resolved && child.resolved =>
-  child.dataType match {
-    case StructType(fields) =>
-      val resolvedFieldName = fields.map(_.name).find(resolver(_, fieldName))
-      resolvedFieldName.map(n => f.copy(fieldName = n)).getOrElse(f)
-    case _ => f
-  }
+case UnresolvedGetField(child, fieldName) if child.resolved =>
+  resolveGetField(child, fieldName)
}
}

@@ -312,6 +305,27 @@
*/
protected def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)

+/**
+ * Returns the resolved `GetField`, and reports an error if the desired field
+ * is not found or if more than one matching field is found.
+ */
+protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
+  expr.dataType match {
+    case StructType(fields) =>
+      val actualField = fields.filter(f => resolver(f.name, fieldName))
+      if (actualField.length == 0) {
+        sys.error(
+          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+      } else if (actualField.length == 1) {
+        val field = actualField(0)
+        GetField(expr, field, fields.indexOf(field))
+      } else {
+        sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
+      }
+    case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
+  }
+}
}

/**
@@ -177,3 +177,15 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
override def toString = expressions.mkString("ResolvedStar(", ", ", ")")
}

+case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
+  override def dataType = throw new UnresolvedException(this, "dataType")
+  override def foldable = throw new UnresolvedException(this, "foldable")
+  override def nullable = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+  override def toString = s"$child.$fieldName"
+}
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -101,7 +101,7 @@ package object dsl {
def isNotNull = IsNotNull(expr)

def getItem(ordinal: Expression) = GetItem(expr, ordinal)
-def getField(fieldName: String) = GetField(expr, fieldName)
+def getField(fieldName: String) = UnresolvedGetField(expr, fieldName)

def cast(to: DataType) = Cast(expr, to)

@@ -73,39 +73,19 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
/**
* Returns the value of fields in the Struct `child`.
*/
-case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
+case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
type EvaluatedType = Any

def dataType = field.dataType
override def nullable = child.nullable || field.nullable
override def foldable = child.foldable

-protected def structType = child.dataType match {
-  case s: StructType => s
-  case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
-}
-
-lazy val field =
-  structType.fields
-    .find(_.name == fieldName)
-    .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))
-
-lazy val ordinal = structType.fields.indexOf(field)
-
-override lazy val resolved = childrenResolved && fieldResolved
-
-/** Returns true only if the fieldName is found in the child struct. */
-private def fieldResolved = child.dataType match {
-  case StructType(fields) => fields.map(_.name).contains(fieldName)
-  case _ => false
-}

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}

-override def toString = s"$child.$fieldName"
+override def toString = s"$child.${field.name}"
}

/**
@@ -206,7 +206,7 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
-case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
+case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -160,11 +160,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
-val aliased =
-  Alias(
-    resolveNesting(nestedFields, a, resolver),
-    nestedFields.last)() // Preserve the case of the user's field access.
-Some(aliased)
+Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)())

// No matches.
case Seq() =>
@@ -177,31 +173,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
}
}

-/**
- * Given a list of successive nested field accesses, and a based expression, attempt to resolve
- * the actual field lookups on this expression.
- */
-private def resolveNesting(
-    nestedFields: List[String],
-    expression: Expression,
-    resolver: Resolver): Expression = {
-
-  (nestedFields, expression.dataType) match {
-    case (Nil, _) => expression
-    case (requestedField :: rest, StructType(fields)) =>
-      val actualField = fields.filter(f => resolver(f.name, requestedField))
-      if (actualField.length == 0) {
-        sys.error(
-          s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
-      } else if (actualField.length == 1) {
-        resolveNesting(rest, GetField(expression, actualField(0).name), resolver)
-      } else {
-        sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
-      }
-    case (_, dt) => sys.error(s"Can't access nested field in type $dt")
-  }
-}
}

/**
@@ -26,6 +26,7 @@ import org.scalatest.FunSuite
import org.scalatest.Matchers._

import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._


@@ -846,23 +847,33 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(null, IntegerType)), null, row)

-checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
-checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
+def quickBuildGetField(expr: Expression, fieldName: String) = {
+  expr.dataType match {
+    case StructType(fields) =>
+      val field = fields.find(_.name == fieldName).get
+      GetField(expr, field, fields.indexOf(field))
+  }
+}
+
+def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName)
+
+checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
+checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row)

val typeS_notNullable = StructType(
StructField("a", StringType, nullable = false)
:: StructField("b", StringType, nullable = false) :: Nil
)

-assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
-assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)
+assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
+assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)

-assert(GetField(Literal(null, typeS), "a").nullable === true)
-assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
+assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true)
+assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true)

checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
-checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
+checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
}

test("arithmetic") {
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

-import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -184,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest {

GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
-GetField(
+UnresolvedGetField(
Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,

3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,6 +23,7 @@ import scala.language.implicitConversions
import org.apache.spark.sql.Dsl.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._


@@ -505,7 +506,7 @@ trait Column extends DataFrame {
/**
* An expression that gets a field by name in a [[StructField]].
*/
-def getField(fieldName: String): Column = exprToColumn(GetField(expr, fieldName))
+def getField(fieldName: String): Column = exprToColumn(UnresolvedGetField(expr, fieldName))

/**
* An expression that returns a substring.
@@ -1038,7 +1038,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
nodeToExpr(qualifier) match {
case UnresolvedAttribute(qualifierName) =>
UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
-case other => GetField(other, attr)
+case other => UnresolvedGetField(other, attr)
}

/* Stars (*) */
@@ -17,8 +17,7 @@

package org.apache.spark.sql.hive.execution

-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive.{sparkContext, sql}
+import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql}
import org.apache.spark.sql.hive.test.TestHive.implicits._

case class Nested(a: Int, B: Int)
@@ -29,16 +28,24 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
*/
class HiveResolutionSuite extends HiveComparisonTest {

-case class NestedData(a: Seq[NestedData2], B: NestedData2)
-case class NestedData2(a: NestedData3, B: NestedData3)
-case class NestedData3(a: Int, B: Int)
-
test("SPARK-3698: case insensitive test for nested data") {
-  sparkContext.makeRDD(Seq.empty[NestedData]).registerTempTable("nested")
+  jsonRDD(sparkContext.makeRDD(
+    """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested")
// This should be successfully analyzed
sql("SELECT a[0].A.A from nested").queryExecution.analyzed
}

test("SPARK-5278: check ambiguous reference to fields") {
jsonRDD(sparkContext.makeRDD(
"""{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested")

// there are 2 filed matching field name "b", we should report Ambiguous reference error
val exception = intercept[RuntimeException] {
sql("SELECT a[0].b from nested").queryExecution.analyzed
}
assert(exception.getMessage.contains("Ambiguous reference to fields"))
}

createQueryTest("table.attr",
"SELECT src.key FROM src ORDER BY key LIMIT 1")

@@ -68,7 +75,7 @@ class HiveResolutionSuite extends HiveComparisonTest {

test("case insensitivity with scala reflection") {
// Test resolution with Scala Reflection
-TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
+sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
.registerTempTable("caseSensitivityTest")

val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
@@ -79,14 +86,14 @@

ignore("case insensitivity with scala reflection joins") {
// Test resolution with Scala Reflection
-TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
+sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
.registerTempTable("caseSensitivityTest")

sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect()
}

test("nested repeated resolution") {
-TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
+sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
.registerTempTable("nestedRepeatedTest")
assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1)
}