Skip to content

Commit

Permalink
[SQL] Better error messages for analysis failures
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Feb 12, 2015
1 parent a38e23c commit 6197cd5
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
import org.apache.spark.sql.types._

/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
Expand Down Expand Up @@ -66,9 +66,7 @@ class Analyzer(catalog: Catalog,
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
CheckResolution ::
CheckAggregation ::
Nil: _*),
CheckResolution),
Batch("AnalysisOperators", fixedPoint,
EliminateAnalysisOperators)
)
Expand All @@ -77,21 +75,70 @@ class Analyzer(catalog: Catalog,
* Makes sure all attributes and logical plans have been resolved.
*/
object CheckResolution extends Rule[LogicalPlan] {
def failAnalysis(msg: String) = { throw new AnalysisException(msg) }

def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUp {
case p if p.expressions.exists(!_.resolved) =>
val missing = p.expressions.filterNot(_.resolved).map(_.prettyString).mkString(",")
val from = p.inputSet.map(_.name).mkString("{", ", ", "}")

throw new AnalysisException(s"Cannot resolve '$missing' given input columns $from")
case p if !p.resolved && p.childrenResolved =>
throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}")
} match {
// As a backstop, use the root node to check that the entire plan tree is resolved.
case p if !p.resolved =>
throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}")
case p => p
plan.foreachUp {
case operator: LogicalPlan =>
operator transformAllExpressions {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.name).mkString("{", ", ", "}")
failAnalysis(s"cannot resolve '$a' given input columns $from")

case c: Cast if !c.resolved =>
failAnalysis(
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")

case b: BinaryExpression if !b.resolved =>
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.simpleString} and ${b.right.simpleString}")


}

operator match {
case f: Filter if f.condition.dataType != BooleanType =>
failAnalysis(s"filter expression '${f.condition.prettyString}' is not a boolean.")

case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
def isValidAggregateExpression(expr: Expression): Boolean = expr match {
case _: AggregateExpression => true
case e: Attribute => groupingExprs.contains(e)
case e if groupingExprs.contains(e) => true
case e if e.references.isEmpty => true
case e => e.children.forall(isValidAggregateExpression)
}

aggregateExprs.find { e =>
!isValidAggregateExpression(e.transform {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
case Alias(g: GetField, _) => g
})
}.foreach { e =>
failAnalysis(s"expression must be aggregates or be in group by $e")
}

aggregatePlan

case o if o.children.nonEmpty && !o.references.subsetOf(o.inputSet) =>
val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
val input = o.inputSet.map(_.prettyString).mkString(",")

failAnalysis(s"resolved attributes $missingAttributes missing from $input")

// Catch all
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")

case _ => // Analysis successful!
}
}

plan
}
}

Expand Down Expand Up @@ -192,37 +239,6 @@ class Analyzer(catalog: Catalog,
}
}

/**
* Checks for non-aggregated attributes with aggregation
*/
object CheckAggregation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transform {
case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
def isValidAggregateExpression(expr: Expression): Boolean = expr match {
case _: AggregateExpression => true
case e: Attribute => groupingExprs.contains(e)
case e if groupingExprs.contains(e) => true
case e if e.references.isEmpty => true
case e => e.children.forall(isValidAggregateExpression)
}

aggregateExprs.find { e =>
!isValidAggregateExpression(e.transform {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
case Alias(g: GetField, _) => g
})
}.foreach { e =>
throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
}

aggregatePlan
}
}
}

/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.Star

protected class AttributeEquals(val a: Attribute) {
override def hashCode() = a.exprId.hashCode()
override def hashCode() = a match {
case ar: AttributeReference => ar.exprId.hashCode()
case a => a.hashCode()
}

override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
case (a1, a2) => a1 == a2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
children.foreach(_.foreach(f))
}

/**
 * Runs the given function recursively on [[children]] then on this node, i.e. a post-order
 * traversal in which every descendant is visited before its parent.
 * @param f the function to be applied to each node in the tree.
 */
def foreachUp(f: BaseType => Unit): Unit = {
  // Recurse with foreachUp rather than foreach: delegating to foreach would visit each
  // subtree parent-first, so only the root would actually be processed after its children.
  children.foreach(_.foreachUp(f))
  f(this)
}

/**
* Returns a Seq containing the result of applying the given function to each
* node in this tree in a preorder traversal.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.{BeforeAndAfter, FunSuite}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.{Literal, Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -108,24 +108,45 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
testRelation)
}

test("throw errors for unresolved attributes during analysis") {
val e = intercept[AnalysisException] {
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
def errorTest(
name: String,
plan: LogicalPlan,
errorMessages: Seq[String],
caseSensitive: Boolean = true) = {
test(name) {
val error = intercept[AnalysisException] {
if(caseSensitive) {
caseSensitiveAnalyze(plan)
} else {
caseInsensitiveAnalyze(plan)
}
}

errorMessages.foreach(m => assert(error.getMessage contains m))
}
assert(e.getMessage().toLowerCase.contains("cannot resolve"))
}

test("throw errors for unresolved plans during analysis") {
case class UnresolvedTestPlan() extends LeafNode {
override lazy val resolved = false
override def output = Nil
}
val e = intercept[AnalysisException] {
caseSensitiveAnalyze(UnresolvedTestPlan())
}
assert(e.getMessage().toLowerCase.contains("unresolved"))
errorTest(
"unresolved attributes",
testRelation.select('abcd),
"cannot resolve" :: "abcd" :: Nil)

errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
"invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)

case class UnresolvedTestPlan() extends LeafNode {
override lazy val resolved = false
override def output = Nil
}

errorTest(
"catch all unresolved plan",
UnresolvedTestPlan(),
"unresolved" :: Nil)


test("divide should be casted into fractional types") {
val testRelation2 = LocalRelation(
AttributeReference("a", StringType)(),
Expand All @@ -134,18 +155,15 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())

val expr0 = 'a / 2
val expr1 = 'a / 'b
val expr2 = 'a / 'c
val expr3 = 'a / 'd
val expr4 = 'e / 'e
val plan = caseInsensitiveAnalyze(Project(
Alias(expr0, s"Analyzer($expr0)")() ::
Alias(expr1, s"Analyzer($expr1)")() ::
Alias(expr2, s"Analyzer($expr2)")() ::
Alias(expr3, s"Analyzer($expr3)")() ::
Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2))
val plan = caseInsensitiveAnalyze(
testRelation2.select(
'a / Literal(2) as 'div1,
'a / 'b as 'div2,
'a / 'c as 'div3,
'a / 'd as 'div4,
'e / 'e as 'div5))
val pl = plan.asInstanceOf[Project].projectList

assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
Expand Down

0 comments on commit 6197cd5

Please sign in to comment.