Skip to content

Commit

Permalink
[SQL] Support df("*") to select all columns in a data frame.
Browse files Browse the repository at this point in the history
This PR makes Star a trait, and provides two implementations: UnresolvedStar (used for *, tblName.*) and ResolvedStar (used for df("*")).

Author: Reynold Xin <rxin@databricks.com>

Closes apache#4283 from rxin/df-star and squashes the following commits:

c9cba3e [Reynold Xin] Removed mapFunction in UnresolvedStar.
1a3a1d7 [Reynold Xin] [SQL] Support df("*") to select all columns in a data frame.
  • Loading branch information
rxin committed Jan 30, 2015
1 parent 22271f9 commit 80def9d
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ class SqlParser extends AbstractSparkSQLParser {
)

protected lazy val baseExpression: Parser[Expression] =
( "*" ^^^ Star(None)
( "*" ^^^ UnresolvedStar(None)
| primary
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def qualifiers = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false

override def newInstance = this
override def newInstance() = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
// Returns an unresolved attribute carrying `newName`. The previous body passed the current
// `name`, which silently ignored the argument and made renaming a no-op.
override def withName(newName: String) = UnresolvedAttribute(newName)
Expand All @@ -77,15 +77,10 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E

/**
* Represents all of the input attributes to a given relational operator, for example in
* "SELECT * FROM ...".
*
* @param table an optional table that should be the target of the expansion. If omitted all
* tables' columns are produced.
* "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
*/
case class Star(
table: Option[String],
mapFunction: Attribute => Expression = identity[Attribute])
extends Attribute with trees.LeafNode[Expression] {
trait Star extends Attribute with trees.LeafNode[Expression] {
self: Product =>

override def name = throw new UnresolvedException(this, "name")
override def exprId = throw new UnresolvedException(this, "exprId")
Expand All @@ -94,29 +89,53 @@ case class Star(
override def qualifiers = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false

override def newInstance = this
override def newInstance() = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
override def withName(newName: String) = this

def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
// Star gets expanded during analysis, so we never evaluate a Star.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression]
}


/**
* Represents all of the input attributes to a given relational operator, for example in
* "SELECT * FROM ...".
*
* @param table an optional table that should be the target of the expansion. If omitted all
* tables' columns are produced.
*/
case class UnresolvedStar(table: Option[String]) extends Star {

override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
val expandedAttributes: Seq[Attribute] = table match {
// If there is no table specified, use all input attributes.
case None => input
// If there is a table, pick out attributes that are part of this table.
case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty)
}
val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map {
expandedAttributes.zip(input).map {
case (n: NamedExpression, _) => n
case (e, originalAttribute) =>
Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers)
}
mappedAttributes
}

// Star gets expanded at runtime so we never evaluate a Star.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString = table.map(_ + ".").getOrElse("") + "*"
}


/**
 * A [[Star]] that already carries the resolved expressions it stands for, covering all the
 * input attributes of a given relational operator. This is used in the data frame DSL.
 *
 * @param expressions Expressions to expand.
 */
case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {

  // The expansion is fixed at construction time, so `input` and `resolver` are not consulted.
  override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
    expressions
  }

  override def toString = "ResolvedStar(" + expressions.mkString(", ") + ")"
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
test("union project *") {
val plan = (1 to 100)
.map(_ => testRelation)
.fold[LogicalPlan](testRelation)((a,b) => a.select(Star(None)).select('a).unionAll(b.select(Star(None))))
.fold[LogicalPlan](testRelation) { (a, b) =>
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}

assert(caseInsensitiveAnalyze(plan).resolved)
}
Expand Down
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.language.implicitConversions

import org.apache.spark.sql.Dsl.lit
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedStar, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -71,8 +71,8 @@ class Column(
* - "df.*" becomes an expression selecting all columns in data frame "df".
*/
def this(name: String) = this(name match {
case "*" => Star(None)
case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2)))
case "*" => UnresolvedStar(None)
case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
case _ => UnresolvedAttribute(name)
})

Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -265,7 +265,7 @@ class DataFrame protected[sql](
*/
override def apply(colName: String): Column = colName match {
case "*" =>
Column("*")
new Column(ResolvedStar(schema.fieldNames.map(resolve)))
case _ =>
val expr = resolve(colName)
new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(testData.select($"*"), testData.collect().toSeq)
}

ignore("star qualified by data frame object") {
test("star qualified by data frame object") {
// This is not yet supported.
val df = testData.toDataFrame
checkAnswer(df.select(df("*")), df.collect().toSeq)
val goldAnswer = df.collect().toSeq
checkAnswer(df.select(df("*")), goldAnswer)

val df1 = df.select(df("*"), lit("abcd").as("litCol"))
checkAnswer(df1.select(df("*")), goldAnswer)
}

test("star qualified by table name") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1002,11 +1002,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}

/* Stars (*) */
case Token("TOK_ALLCOLREF", Nil) => Star(None)
case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
// The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
// have a single child, which is tableName.
case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
Star(Some(name))
UnresolvedStar(Some(name))

/* Aggregate Functions */
case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
Expand Down Expand Up @@ -1145,7 +1145,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
UnresolvedFunction(name, args.map(nodeToExpr))
case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
UnresolvedFunction(name, Star(None) :: Nil)
UnresolvedFunction(name, UnresolvedStar(None) :: Nil)

/* Literals */
case Token("TOK_NULL", Nil) => Literal(null, NullType)
Expand Down

0 comments on commit 80def9d

Please sign in to comment.