support dot notation on array of struct
cloud-fan committed Feb 7, 2015
1 parent 1390e56 commit 08a228a
Showing 5 changed files with 53 additions and 22 deletions.
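For orientation (an editorial sketch, not part of the commit itself): the change makes dot notation work when the qualifier is an array of structs, returning the named field from every element. Reusing the jsonTable / arrayOfStruct names from the JsonSuite test at the end of this diff:

// Indexing into the array, then taking a field (already supported): one value per row.
sql("select arrayOfStruct[0].field1 from jsonTable")   // true
// Dot notation on the whole array (new in this commit): one value per array element.
sql("select arrayOfStruct.field1 from jsonTable")      // Seq(true, false, null)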
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog,
    * desired fields are found.
    */
   protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
+    def findField(fields: Array[StructField]): Int = {
+      val checkField = (f: StructField) => resolver(f.name, fieldName)
+      val ordinal = fields.indexWhere(checkField)
+      if (ordinal == -1) {
+        sys.error(
+          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+        sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+      } else {
+        ordinal
+      }
+    }
     expr.dataType match {
       case StructType(fields) =>
-        val actualField = fields.filter(f => resolver(f.name, fieldName))
-        if (actualField.length == 0) {
-          sys.error(
-            s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-        } else if (actualField.length == 1) {
-          val field = actualField(0)
-          GetField(expr, field, fields.indexOf(field))
-        } else {
-          sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
-        }
+        val ordinal = findField(fields)
+        StructGetField(expr, fields(ordinal), ordinal)
+      case ArrayType(StructType(fields), containsNull) =>
+        val ordinal = findField(fields)
+        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
       case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
     }
   }
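An editorial sketch of what resolveGetField now returns for each kind of child; the expressions emp and emps and the field names are hypothetical, while the constructors are the ones introduced in this commit:

import org.apache.spark.sql.types._

// Hypothetical schema: emp is struct<name: string>, emps is array<struct<name: string>>.
val nameField  = StructField("name", StringType)
val structType = StructType(Seq(nameField))
val arrayType  = ArrayType(structType, containsNull = true)

// resolveGetField(emp,  "name")  ==>  StructGetField(emp, nameField, 0)
// resolveGetField(emps, "name")  ==>  ArrayGetField(emps, nameField, 0, containsNull = true)
// Any other child type, or a missing or ambiguous field name, is reported through sys.error as above.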
@@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
   }
 }
 
+
+trait GetField extends UnaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+  override def foldable = child.foldable
+  override def toString = s"$child.${field.name}"
+
+  def field: StructField
+}
+
 /**
  * Returns the value of fields in the Struct `child`.
  */
-case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
-  type EvaluatedType = Any
+case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
 
   def dataType = field.dataType
   override def nullable = child.nullable || field.nullable
-  override def foldable = child.foldable
 
   override def eval(input: Row): Any = {
     val baseValue = child.eval(input).asInstanceOf[Row]
     if (baseValue == null) null else baseValue(ordinal)
   }
+}
 
-  override def toString = s"$child.${field.name}"
+/**
+ * Returns the array of value of fields in the Array of Struct `child`.
+ */
+case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
+  extends GetField {
+
+  def dataType = ArrayType(field.dataType, containsNull)
+  override def nullable = child.nullable
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+    if (baseValue == null) null else {
+      baseValue.map { row =>
+        if (row == null) null else row(ordinal)
+      }
+    }
+  }
 }
 
 /**
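A standalone editorial sketch of the mapping ArrayGetField.eval performs, using plain Scala collections in place of catalyst rows; the data is hypothetical and ordinal 0 stands for field1:

// Each inner Seq plays the role of a struct row; a null element is a null struct.
val structs: Seq[Seq[Any]] = Seq(Seq(true, "str1"), Seq(false, null), null)

val field1 = if (structs == null) null else {
  structs.map(row => if (row == null) null else row(0))
}
// field1 == Seq(true, false, null): null structs contribute null entries instead of failing.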
@@ -206,7 +206,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
       case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
       case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
-      case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+      case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+      case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case e @ Count(expr) if !expr.nullable => Count(Literal(1))
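Written out as an editorial before/after on the expression tree (structType, arrayType and field are hypothetical; e.dataType follows the dataType definitions added above):

// Field access on a literal null struct folds to a typed null value:
//   StructGetField(Literal(null, structType), field, 0)
//     ==> Literal(null, field.dataType)
// Field access on a literal null array of structs folds to a typed null array:
//   ArrayGetField(Literal(null, arrayType), field, 0, containsNull)
//     ==> Literal(null, ArrayType(field.dataType, containsNull))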
@@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite {
     expr.dataType match {
       case StructType(fields) =>
         val field = fields.find(_.name == fieldName).get
-        GetField(expr, field, fields.indexOf(field))
+        StructGetField(expr, field, fields.indexOf(field))
     }
   }
 
@@ -342,21 +342,19 @@ class JsonSuite extends QueryTest {
     )
   }
 
-  ignore("Complex field and type inferring (Ignored)") {
+  test("GetField operation on complex data type") {
    val jsonDF = jsonRDD(complexFieldAndType1)
    jsonDF.registerTempTable("jsonTable")
 
-    // Right now, "field1" and "field2" are treated as aliases. We should fix it.
     checkAnswer(
       sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
       Row(true, "str1")
     )
 
-    // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2.
     // Getting all values of a specific field from an array of structs.
     checkAnswer(
       sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
-      Row(Seq(true, false), Seq("str1", null))
+      Row(Seq(true, false, null), Seq("str1", null, null))
     )
   }
 
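Editorial note on the changed expectation: the result now has three entries per column because arrayOfStruct holds three structs and ArrayGetField maps over all of them, yielding null wherever a struct does not define the field. A hypothetical shape of the arrayOfStruct column that is consistent with the answers above (the real complexFieldAndType1 fixture may differ in detail):

// arrayOfStruct: [ {"field1": true,  "field2": "str1"},
//                  {"field1": false},          // no field2     -> null
//                  { ... } ]                   // neither field -> null, null
// select arrayOfStruct.field1 ...  -> Seq(true, false, null)
// select arrayOfStruct.field2 ...  -> Seq("str1", null, null)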