[SPARK-7133] [SQL] Implement struct, array, and map field accessor
It's the first step: generalize `UnresolvedGetField` to support extraction from maps, structs, and arrays.
TODO: add `apply` in Scala and `__getitem__` in Python, and unify the `getItem` and `getField` methods into a single API (or should we keep both for compatibility?).

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes apache#5744 from cloud-fan/generalize and squashes the following commits:

715c589 [Wenchen Fan] address comments
7ea5b31 [Wenchen Fan] fix python test
4f0833a [Wenchen Fan] add python test
f515d69 [Wenchen Fan] add apply method and test cases
8df6199 [Wenchen Fan] fix python test
239730c [Wenchen Fan] fix test compile
2a70526 [Wenchen Fan] use _bin_op in dataframe.py
6bf72bc [Wenchen Fan] address comments
3f880c3 [Wenchen Fan] add java doc
ab35ab5 [Wenchen Fan] fix python test
b5961a9 [Wenchen Fan] fix style
c9d85f5 [Wenchen Fan] generalize UnresolvedGetField to support all map, struct, and array
cloud-fan authored and nemccarthy committed Jun 19, 2015
1 parent 462ea6f commit f39969f
Showing 16 changed files with 327 additions and 191 deletions.
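In user-facing terms, the commit routes Python's `__getitem__` through the new unified `apply` method on the JVM Column, so bracket access works the same way on array, struct, and map columns. A minimal sketch of the resulting behavior, adapted from the doctests and tests in the diff below (it assumes the PySpark doctest environment, i.e. a running SparkContext `sc`):

    from pyspark.sql import Row

    # One row with an array, a struct, and a map column.
    df = sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()

    # The same bracket syntax now covers all three container types:
    df.select(df.l[0]).first()[0]    # 1   (array index)
    df.select(df.r["a"]).first()[0]  # 1   (struct field)
    df.select(df.d["k"]).first()[0]  # 'v' (map key)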
24 changes: 12 additions & 12 deletions python/pyspark/sql/dataframe.py
@@ -1275,7 +1275,7 @@ def __init__(self, jc):

# container operators
__contains__ = _bin_op("contains")
__getitem__ = _bin_op("getItem")
__getitem__ = _bin_op("apply")

# bitwise operators
bitwiseOR = _bin_op("bitwiseOR")
@@ -1308,19 +1308,19 @@ def getField(self, name):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
-        +---+
-        |r.b|
-        +---+
-        |  b|
-        +---+
+        +----+
+        |r[b]|
+        +----+
+        |   b|
+        +----+
>>> df.select(df.r.a).show()
-        +---+
-        |r.a|
-        +---+
-        |  1|
-        +---+
+        +----+
+        |r[a]|
+        +----+
+        |   1|
+        +----+
"""
-        return Column(self._jc.getField(name))
+        return self[name]

def __getattr__(self, item):
if item.startswith("__"):
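With this change `getField` is plain sugar for bracket access, which is also why the doctest output above switches from `r.b` to `r[b]`: the unresolved extraction's `toString` now renders as `child[extraction]` (see the unresolved.scala hunk below). For a struct column such as `df.r`, the three spellings below should all land on the same extraction path; a sketch reusing the `df` from the doctest:

    df.select(df.r.getField("b"))  # explicit method, now returns self["b"]
    df.select(df.r["b"])           # __getitem__, i.e. the unified apply path
    df.select(df.r.b)              # __getattr__ falls back to getField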
7 changes: 7 additions & 0 deletions python/pyspark/sql/tests.py
@@ -519,6 +519,13 @@ def test_access_nested_types(self):
self.assertEqual("v", df.select(df.d["k"]).first()[0])
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])

+    def test_field_accessor(self):
+        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+        self.assertEqual(1, df.select(df.l[0]).first()[0])
+        self.assertEqual(1, df.select(df.r["a"]).first()[0])
+        self.assertEqual("b", df.select(df.r["b"]).first()[0])
+        self.assertEqual("v", df.select(df.d["k"]).first()[0])
+
def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
df = self.sc.parallelize(longrow).toDF()
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val primary: PackratParser[Expression] =
( literal
| expression ~ ("[" ~> expression <~ "]") ^^
-      { case base ~ ordinal => GetItem(base, ordinal) }
+      { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
| (expression <~ ".") ~ ident ^^
-      { case base ~ fieldName => UnresolvedGetField(base, fieldName) }
+      { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
| cast
| "(" ~> expression <~ ")"
| function
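Both parser branches now produce `UnresolvedExtractValue`: the bracket form passes the ordinal expression through as-is, while the dotted form wraps the identifier in a `Literal`. A sketch of what this means at the SQL level, assuming the earlier `df` is registered as a temp table (the table name `t` and the doctest environment's `sqlContext` are illustrative):

    df.registerTempTable("t")

    # "l[0]" and "d['k']" take the bracket branch; "r.a" takes the dotted-field
    # branch. All three become UnresolvedExtractValue and are resolved by the
    # Analyzer rule shown in the next hunk.
    sqlContext.sql("SELECT l[0], r.a, d['k'] FROM t").first()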
@@ -348,8 +348,8 @@ class Analyzer(
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
-      case UnresolvedGetField(child, fieldName) if child.resolved =>
-        GetField(child, fieldName, resolver)
+      case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+        ExtractValue(child, fieldExpr, resolver)
}
}

@@ -184,7 +184,17 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}

-case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
+/**
+ * Extracts a value or values from an Expression.
+ *
+ * @param child The expression to extract a value from;
+ *              can be a Map, Array, Struct, or array of Structs.
+ * @param extraction The expression describing the extraction;
+ *                   can be a key of a Map, an index of an Array, or a field name of a Struct.
+ */
+case class UnresolvedExtractValue(child: Expression, extraction: Expression)
+  extends UnaryExpression {
+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -193,5 +203,5 @@ case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

-  override def toString: String = s"$child.$fieldName"
+  override def toString: String = s"$child[$extraction]"
}
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

-import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -100,8 +100,9 @@ package object dsl {
def isNull: Predicate = IsNull(expr)
def isNotNull: Predicate = IsNotNull(expr)

-    def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
-    def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
+    def getItem(ordinal: Expression): UnresolvedExtractValue = UnresolvedExtractValue(expr, ordinal)
+    def getField(fieldName: String): UnresolvedExtractValue =
+      UnresolvedExtractValue(expr, Literal(fieldName))

def cast(to: DataType): Expression = Cast(expr, to)

@@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.Map

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._

object ExtractValue {
/**
* Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`,
* depending on the types of `child` and `extraction`.
*
* `child` | `extraction` | concrete `ExtractValue`
* ----------------------------------------------------------------
* Struct | Literal String | GetStructField
* Array[Struct] | Literal String | GetArrayStructFields
* Array | Integral type | GetArrayItem
* Map | Any type | GetMapValue
*/
def apply(
child: Expression,
extraction: Expression,
resolver: Resolver): ExtractValue = {

(child.dataType, extraction) match {
case (StructType(fields), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetStructField(child, fields(ordinal), ordinal)
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
case (_: MapType, _) =>
GetMapValue(child, extraction)
case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
s"Field name should be String Literal, but it's $extraction"
case _: ArrayType =>
s"Array index should be integral type, but it's ${extraction.dataType}"
case other =>
s"Can't extract value from $child"
}
throw new AnalysisException(errorMsg)
}
}

def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
g match {
case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
case _ => Some((g.child, null))
}
}

/**
* Find the ordinal of the StructField; report an error if the desired field is not found
* or if more than one matching field is found.
*/
private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
val checkField = (f: StructField) => resolver(f.name, fieldName)
val ordinal = fields.indexWhere(checkField)
if (ordinal == -1) {
throw new AnalysisException(
s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
} else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
throw new AnalysisException(
s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
} else {
ordinal
}
}
}

trait ExtractValue extends UnaryExpression {
self: Product =>

type EvaluatedType = Any
}

/**
* Returns the value of the field at `ordinal` in the Struct `child`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends ExtractValue {

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}
}

/**
* Returns an array containing the value of the field in each Struct element of the Array `child`.
*/
case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
containsNull: Boolean) extends ExtractValue {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
if (baseValue == null) null else {
baseValue.map { row =>
if (row == null) null else row(ordinal)
}
}
}
}

abstract class ExtractValueWithOrdinal extends ExtractValue {
self: Product =>

def ordinal: Expression

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
override def foldable: Boolean = child.foldable && ordinal.foldable
override def toString: String = s"$child[$ordinal]"
override def children: Seq[Expression] = child :: ordinal :: Nil

override def eval(input: Row): Any = {
val value = child.eval(input)
if (value == null) {
null
} else {
val o = ordinal.eval(input)
if (o == null) {
null
} else {
evalNotNull(value, o)
}
}
}

protected def evalNotNull(value: Any, ordinal: Any): Any
}

/**
* Returns the field at `ordinal` in the Array `child`.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
extends ExtractValueWithOrdinal {

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

override lazy val resolved = childrenResolved &&
child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]

protected def evalNotNull(value: Any, ordinal: Any) = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
val index = ordinal.asInstanceOf[Int]
if (index >= baseValue.size || index < 0) {
null
} else {
baseValue(index)
}
}
}

/**
* Returns the value of key `ordinal` in the Map `child`.
*/
case class GetMapValue(child: Expression, ordinal: Expression)
extends ExtractValueWithOrdinal {

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]

protected def evalNotNull(value: Any, ordinal: Any) = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
}
}
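The dispatch table at the top of this new file can be exercised end-to-end from Python: each accessor below should resolve to the concrete `ExtractValue` named in its comment, and the commented-out line would take the error branch. A sketch under the same assumptions as the earlier examples (running `sc`; the nested column layout is illustrative):

    from pyspark.sql import Row

    df = sc.parallelize([Row(r=Row(a=1), rs=[Row(a=1), Row(a=2)],
                             l=[1, 2], d={"k": "v"})]).toDF()

    df.select(df.r["a"])   # Struct        + string literal   -> GetStructField
    df.select(df.rs["a"])  # Array[Struct] + string literal   -> GetArrayStructFields
    df.select(df.l[1])     # Array         + integral ordinal -> GetArrayItem
    df.select(df.d["k"])   # Map           + any key type     -> GetMapValue

    # A non-integral index on a plain array raises an AnalysisException:
    # df.select(df.l["x"])  # "Array index should be integral type, but it's StringType"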
…