
[SQL] Various DataFrame DSL update.
1. Added foreach, foreachPartition, flatMap to DataFrame.
2. Added col() in dsl.
3. Support renaming columns in toDataFrame.
4. Support type inference on arrays (in addition to Seq).
rxin committed Jan 29, 2015
1 parent 5b9760d commit 62608c4
Showing 11 changed files with 103 additions and 37 deletions.
3 changes: 1 addition & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.types._

@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
dataset.select($"*", callUDF(
this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
}
}
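The change above swaps the bare Column(...) constructor for dataset(columnName), so the UDF input is resolved against the DataFrame being transformed. A minimal sketch of the same select-all-plus-derived-column pattern on its own (the sqlContext, df, the column names, and the UDF here are illustrative assumptions, not code from this commit):

    import org.apache.spark.sql.api.scala.dsl._
    import org.apache.spark.sql.types._

    // df is assumed to be an existing DataFrame with a string column "text".
    val toUpper: String => String = _.toUpperCase
    // Keep every existing column and append one produced by the UDF,
    // resolving the input column against df itself rather than by bare name.
    val transformed = df.select($"*", callUDF(toUpper, StringType, df("text")).as("text_upper"))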
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.storage.StorageLevel

@@ -138,10 +137,10 @@ class LogisticRegressionModel private[ml] (
1.0 / (1.0 + math.exp(-margin))
}
val t = map(threshold)
val predict: Double => Double = (score) => {
if (score > t) 1.0 else 0.0
}
dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
.select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
val predict: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
dataset.select(
$"*",
callUDF(score, dataset(map(featuresCol))).as(map(scoreCol)),
callUDF(predict, dataset(map(scoreCol))).as(map(predictionCol)))
}
}
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{StructField, StructType}

/**
@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
val scale: (Vector) => Vector = (v) => {
scaler.transform(v)
}
dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
}

private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
35 changes: 11 additions & 24 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext._
import org.apache.spark.ml.recommendation.ALSModel.Factor
import dataset.sqlContext.createDataFrame
val map = this.paramMap ++ paramMap
// TODO: Add DSL to simplify the code here.
val instanceTable = s"instance_$uid"
val userTable = s"user_$uid"
val itemTable = s"item_$uid"
val instances = dataset.as(instanceTable)
val users = userFactors.map { case (id, features) =>
Factor(id, features)
}.as(userTable)
val items = itemFactors.map { case (id, features) =>
Factor(id, features)
}.as(itemTable)
val users = userFactors.toDataFrame("id", "features")
val items = itemFactors.toDataFrame("id", "features")
val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
if (userFeatures != null && itemFeatures != null) {
blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -133,24 +123,21 @@
}
}
val inputColumns = dataset.schema.fieldNames
val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
.as(map(predictionCol))
val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
instances
.join(users, Column(map(userCol)) === $"$userTable.id", "left")
.join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
dataset
.join(users, dataset(map(userCol)) === users("id"), "left")
.join(items, dataset(map(itemCol)) === items("id"), "left")
.select(outputColumns: _*)
// TODO: Just use a dataset("*")
//.select(dataset("*"), prediction)
}

override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}

private object ALSModel {
/** Case class to convert factors to [[DataFrame]]s */
private case class Factor(id: Int, features: Seq[Float])
}

/**
* Alternating Least Squares (ALS) matrix factorization.
@@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
val ratings = dataset
.select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
.select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
.map { row =>
new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
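The rewritten ALSModel.transform above leans on two pieces added in this commit: the renaming toDataFrame (so a tuple RDD of factors becomes a DataFrame with id/features columns, making the Factor case class unnecessary) and joins expressed as column equality resolved against each DataFrame. A rough sketch of that pattern in isolation (the sqlContext, the input RDD and DataFrame, and the column names are assumptions for illustration):

    import org.apache.spark.sql.api.scala.dsl._

    // Assumed inputs: ratings is a DataFrame with "userId" and "itemId" columns;
    // userFactorsRdd is an RDD[(Int, Seq[Float])] of (user id, latent features).
    import sqlContext.createDataFrame  // implicit tuple-RDD-to-DataFrame conversion, as imported above
    val users = userFactorsRdd.toDataFrame("id", "features")

    // Left-join on column equality, keeping the original columns plus the user factors.
    val joined = ratings
      .join(users, ratings("userId") === users("id"), "left")
      .select(ratings("userId"), ratings("itemId"), users("features").as("userFeatures"))

Compared with the old version, no table aliases or string-interpolated qualified names are needed; each column reference carries the DataFrame it belongs to.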
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}

import org.apache.spark.SparkException
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types._

/**
@@ -57,6 +57,7 @@ trait ScalaReflection {
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
case (s: Array[_], arrayType: ArrayType) => s.toSeq
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
@@ -140,7 +141,9 @@ trait ScalaReflection {
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
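With the new Array[_] case in schemaFor, an Array-typed field is now inferred as an ArrayType just as a Seq is, instead of failing with "Only Array[Byte] supported now". A small sketch of the effect (the case class and the direct schemaFor call are illustrative; the test suite below exercises the same path):

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types._

    case class Record(tags: Array[Int], labels: Seq[String])

    // The Array field is now mapped to ArrayType(IntegerType, containsNull = false),
    // the same shape the Seq-based field has always produced.
    val schema = ScalaReflection.schemaFor[Record]

Note that Array[Byte] still matches the earlier branch and maps to BinaryType.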
@@ -60,6 +60,7 @@ case class OptionalData(

case class ComplexData(
arrayField: Seq[Int],
arrayField1: Array[Int],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -131,6 +132,10 @@ class ScalaReflectionSuite extends FunSuite {
"arrayField",
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
"arrayField1",
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
"arrayFieldContainsNull",
ArrayType(IntegerType, containsNull = true),
41 changes: 40 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -131,6 +131,29 @@ class DataFrame protected[sql](
*/
def toDataFrame: DataFrame = this

/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient when converting
* an RDD of tuples into a [[DataFrame]] with meaningful column names. For example:
* {{{
* val rdd: RDD[(Int, String)] = ...
* rdd.toDataFrame // this implicit conversion creates a DataFrame with column names _1 and _2
* rdd.toDataFrame("id", "name") // this creates a DataFrame with column names "id" and "name"
* }}}
*/
@scala.annotation.varargs
def toDataFrame(colName: String, colNames: String*): DataFrame = {
val newNames = colName +: colNames
require(schema.size == newNames.size,
"The number of columns doesn't match.\n" +
"Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
"New column names: " + newNames.mkString(", "))

val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
apply(oldName).as(newName)
}
select(newCols :_*)
}

/** Returns the schema of this [[DataFrame]]. */
override def schema: StructType = queryExecution.analyzed.schema

@@ -466,13 +489,29 @@ class DataFrame protected[sql](
rdd.map(f)
}

/**
* Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
* and then flattening the results.
*/
override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)

/**
* Returns a new RDD by applying a function to each partition of this [[DataFrame]].
*/
override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
rdd.mapPartitions(f)
}

/**
* Applies a function `f` to all rows.
*/
override def foreach(f: Row => Unit): Unit = rdd.foreach(f)

/**
* Applies a function `f` to each partition of this [[DataFrame]].
*/
override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)

/**
* Returns the first `n` rows in the [[DataFrame]].
*/
@@ -520,7 +559,7 @@ class DataFrame protected[sql](
/////////////////////////////////////////////////////////////////////////////

/**
* Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
* Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
*/
override def rdd: RDD[Row] = {
val schema = this.schema
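A quick sketch of the new RDD-style operations now available directly on a DataFrame (assuming a DataFrame df whose first column is a string; the variable and column are illustrative):

    // Each of these delegates to df.rdd, so the supplied function receives Rows.
    val words = df.flatMap(row => row.getString(0).split(" "))  // RDD[String]

    // foreach and foreachPartition run on the executors; foreachPartition is the usual
    // place for per-partition side effects such as writing through a single connection.
    df.foreach(row => println(row))
    df.foreachPartition { rows =>
      rows.foreach(row => println(row))  // stand-in for real per-partition work
    }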
6 changes: 6 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {

def map[R: ClassTag](f: T => R): RDD[R]

def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]

def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]

def foreach(f: T => Unit): Unit

def foreachPartition(f: Iterator[T] => Unit): Unit

def take(n: Int): Array[T]

def collect(): Array[T]
@@ -32,6 +32,13 @@ public class dsl {

private static package$ scalaDsl = package$.MODULE$;

/**
* Returns a {@link Column} based on the given column name.
*/
public static Column col(String colName) {
return new Column(colName);
}

/**
* Creates a column of literal value.
*/
@@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
@@ -37,6 +38,21 @@ package object dsl {
/** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

// /**
// * An implicit conversion that turns a RDD of product into a [[DataFrame]].
// *
// * This method requires an implicit SQLContext in scope. For example:
// * {{{
// * implicit val sqlContext: SQLContext = ...
// * val rdd: RDD[(Int, String)] = ...
// * rdd.toDataFrame // triggers the implicit here
// * }}}
// */
// implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
// : DataFrame = {
// context.createDataFrame(rdd)
// }

/** Converts $"col name" into a [[Column]]. */
implicit class StringToColumn(val sc: StringContext) extends AnyVal {
def $(args: Any*): ColumnName = {
@@ -46,6 +62,11 @@

private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)

/**
* Returns a [[Column]] based on the given column name.
*/
def col(colName: String): Column = new Column(colName)

/**
* Creates a [[Column]] of literal value.
*/
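The new col function, added to both the Java dsl class above and this Scala package object, gives a plain-method way to refer to a column by name, alongside the Symbol and $"..." conversions already defined in this file. A small usage sketch (the DataFrame df and its "age" column are assumptions):

    import org.apache.spark.sql.api.scala.dsl._

    // These all resolve the same column of df:
    df.select(col("age"))
    df.select($"age")
    df.select('age)

    // col is convenient when the column name is only known at runtime:
    val name = "age"
    df.select(col(name).as("user_age"))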
